diff --git a/lib/std/crypto.zig b/lib/std/crypto.zig index 7b167a467a..a444b41cc3 100644 --- a/lib/std/crypto.zig +++ b/lib/std/crypto.zig @@ -220,6 +220,7 @@ pub const errors = @import("crypto/errors.zig"); pub const tls = @import("crypto/tls.zig"); pub const Certificate = @import("crypto/Certificate.zig"); +pub const asn1 = @import("crypto/asn1.zig"); /// Side-channels mitigations. pub const SideChannelsMitigations = enum { @@ -334,6 +335,7 @@ test { _ = errors; _ = tls; _ = Certificate; + _ = asn1; } test "CSPRNG" { diff --git a/lib/std/crypto/asn1.zig b/lib/std/crypto/asn1.zig new file mode 100644 index 0000000000..b9c0a9e109 --- /dev/null +++ b/lib/std/crypto/asn1.zig @@ -0,0 +1,359 @@ +//! ASN.1 types for public consumption. +const std = @import("std"); +pub const der = @import("./asn1/der.zig"); +pub const Oid = @import("./asn1/Oid.zig"); + +pub const Index = u32; + +pub const Tag = struct { + number: Number, + /// Whether this ASN.1 type contains other ASN.1 types. + constructed: bool, + class: Class, + + /// These values apply to class == .universal. + pub const Number = enum(u16) { + // 0 is reserved by spec + boolean = 1, + integer = 2, + bitstring = 3, + octetstring = 4, + null = 5, + oid = 6, + object_descriptor = 7, + real = 9, + enumerated = 10, + embedded = 11, + string_utf8 = 12, + oid_relative = 13, + time = 14, + // 15 is reserved to mean that the tag is >= 32 + sequence = 16, + /// Elements may appear in any order. + sequence_of = 17, + string_numeric = 18, + string_printable = 19, + string_teletex = 20, + string_videotex = 21, + string_ia5 = 22, + utc_time = 23, + generalized_time = 24, + string_graphic = 25, + string_visible = 26, + string_general = 27, + string_universal = 28, + string_char = 29, + string_bmp = 30, + date = 31, + time_of_day = 32, + date_time = 33, + duration = 34, + /// IRI = Internationalized Resource Identifier + oid_iri = 35, + oid_iri_relative = 36, + _, + }; + + pub const Class = enum(u2) { + universal, + application, + context_specific, + private, + }; + + pub fn init(number: Tag.Number, constructed: bool, class: Tag.Class) Tag { + return .{ .number = number, .constructed = constructed, .class = class }; + } + + pub fn universal(number: Tag.Number, constructed: bool) Tag { + return .{ .number = number, .constructed = constructed, .class = .universal }; + } + + pub fn decode(reader: anytype) !Tag { + const tag1: FirstTag = @bitCast(try reader.readByte()); + var number: u14 = tag1.number; + + if (tag1.number == 15) { + const tag2: NextTag = @bitCast(try reader.readByte()); + number = tag2.number; + if (tag2.continues) { + const tag3: NextTag = @bitCast(try reader.readByte()); + number = (number << 7) + tag3.number; + if (tag3.continues) return error.InvalidLength; + } + } + + return Tag{ + .number = @enumFromInt(number), + .constructed = tag1.constructed, + .class = tag1.class, + }; + } + + pub fn encode(self: Tag, writer: anytype) @TypeOf(writer).Error!void { + var tag1 = FirstTag{ + .number = undefined, + .constructed = self.constructed, + .class = self.class, + }; + + var buffer: [3]u8 = undefined; + var stream = std.io.fixedBufferStream(&buffer); + var writer2 = stream.writer(); + + switch (@intFromEnum(self.number)) { + 0...std.math.maxInt(u5) => |n| { + tag1.number = @intCast(n); + writer2.writeByte(@bitCast(tag1)) catch unreachable; + }, + std.math.maxInt(u5) + 1...std.math.maxInt(u7) => |n| { + tag1.number = 15; + const tag2 = NextTag{ .number = @intCast(n), .continues = false }; + writer2.writeByte(@bitCast(tag1)) catch unreachable; + writer2.writeByte(@bitCast(tag2)) catch unreachable; + }, + else => |n| { + tag1.number = 15; + const tag2 = NextTag{ .number = @intCast(n >> 7), .continues = true }; + const tag3 = NextTag{ .number = @truncate(n), .continues = false }; + writer2.writeByte(@bitCast(tag1)) catch unreachable; + writer2.writeByte(@bitCast(tag2)) catch unreachable; + writer2.writeByte(@bitCast(tag3)) catch unreachable; + }, + } + + _ = try writer.write(stream.getWritten()); + } + + const FirstTag = packed struct(u8) { number: u5, constructed: bool, class: Tag.Class }; + const NextTag = packed struct(u8) { number: u7, continues: bool }; + + pub fn toExpected(self: Tag) ExpectedTag { + return ExpectedTag{ + .number = self.number, + .constructed = self.constructed, + .class = self.class, + }; + } + + pub fn fromZig(comptime T: type) Tag { + switch (@typeInfo(T)) { + .Struct, .Enum, .Union => { + if (@hasDecl(T, "asn1_tag")) return T.asn1_tag; + }, + else => {}, + } + + switch (@typeInfo(T)) { + .Struct, .Union => return universal(.sequence, true), + .Bool => return universal(.boolean, false), + .Int => return universal(.integer, false), + .Enum => |e| { + if (@hasDecl(T, "oids")) return Oid.asn1_tag; + return universal(if (e.is_exhaustive) .enumerated else .integer, false); + }, + .Optional => |o| return fromZig(o.child), + .Null => return universal(.null, false), + else => @compileError("cannot map Zig type to asn1_tag " ++ @typeName(T)), + } + } +}; + +test Tag { + const buf = [_]u8{0xa3}; + var stream = std.io.fixedBufferStream(&buf); + const t = Tag.decode(stream.reader()); + try std.testing.expectEqual(Tag.init(@enumFromInt(3), true, .context_specific), t); +} + +/// A decoded view. +pub const Element = struct { + tag: Tag, + slice: Slice, + + pub const Slice = struct { + start: Index, + end: Index, + + pub fn len(self: Slice) Index { + return self.end - self.start; + } + + pub fn view(self: Slice, bytes: []const u8) []const u8 { + return bytes[self.start..self.end]; + } + }; + + pub const DecodeError = error{ InvalidLength, EndOfStream }; + + /// Safely decode a DER/BER/CER element at `index`: + /// - Ensures length uses shortest form + /// - Ensures length is within `bytes` + /// - Ensures length is less than `std.math.maxInt(Index)` + pub fn decode(bytes: []const u8, index: Index) DecodeError!Element { + var stream = std.io.fixedBufferStream(bytes[index..]); + var reader = stream.reader(); + + const tag = try Tag.decode(reader); + const size_or_len_size = try reader.readByte(); + + var start = index + 2; + var end = start + size_or_len_size; + // short form between 0-127 + if (size_or_len_size < 128) { + if (end > bytes.len) return error.InvalidLength; + } else { + // long form between 0 and std.math.maxInt(u1024) + const len_size: u7 = @truncate(size_or_len_size); + start += len_size; + if (len_size > @sizeOf(Index)) return error.InvalidLength; + + const len = try reader.readVarInt(Index, .big, len_size); + if (len < 128) return error.InvalidLength; // should have used short form + + end = std.math.add(Index, start, len) catch return error.InvalidLength; + if (end > bytes.len) return error.InvalidLength; + } + + return Element{ .tag = tag, .slice = Slice{ .start = start, .end = end } }; + } +}; + +test Element { + const short_form = [_]u8{ 0x30, 0x03, 0x02, 0x01, 0x09 }; + try std.testing.expectEqual(Element{ + .tag = Tag.universal(.sequence, true), + .slice = Element.Slice{ .start = 2, .end = short_form.len }, + }, Element.decode(&short_form, 0)); + + const long_form = [_]u8{ 0x30, 129, 129 } ++ [_]u8{0} ** 129; + try std.testing.expectEqual(Element{ + .tag = Tag.universal(.sequence, true), + .slice = Element.Slice{ .start = 3, .end = long_form.len }, + }, Element.decode(&long_form, 0)); +} + +/// For decoding. +pub const ExpectedTag = struct { + number: ?Tag.Number = null, + constructed: ?bool = null, + class: ?Tag.Class = null, + + pub fn init(number: ?Tag.Number, constructed: ?bool, class: ?Tag.Class) ExpectedTag { + return .{ .number = number, .constructed = constructed, .class = class }; + } + + pub fn primitive(number: ?Tag.Number) ExpectedTag { + return .{ .number = number, .constructed = false, .class = .universal }; + } + + pub fn match(self: ExpectedTag, tag: Tag) bool { + if (self.number) |e| { + if (tag.number != e) return false; + } + if (self.constructed) |e| { + if (tag.constructed != e) return false; + } + if (self.class) |e| { + if (tag.class != e) return false; + } + return true; + } +}; + +pub const FieldTag = struct { + number: std.meta.Tag(Tag.Number), + class: Tag.Class, + explicit: bool = true, + + pub fn explicit(number: std.meta.Tag(Tag.Number), class: Tag.Class) FieldTag { + return FieldTag{ .number = number, .class = class, .explicit = true }; + } + + pub fn implicit(number: std.meta.Tag(Tag.Number), class: Tag.Class) FieldTag { + return FieldTag{ .number = number, .class = class, .explicit = false }; + } + + pub fn fromContainer(comptime Container: type, comptime field_name: []const u8) ?FieldTag { + if (@hasDecl(Container, "asn1_tags") and @hasField(@TypeOf(Container.asn1_tags), field_name)) { + return @field(Container.asn1_tags, field_name); + } + + return null; + } + + pub fn toTag(self: FieldTag) Tag { + return Tag.init(@enumFromInt(self.number), self.explicit, self.class); + } +}; + +pub const BitString = struct { + /// Number of bits in rightmost byte that are unused. + right_padding: u3 = 0, + bytes: []const u8, + + pub fn bitLen(self: BitString) usize { + return self.bytes.len * 8 - self.right_padding; + } + + const asn1_tag = Tag.universal(.bitstring, false); + + pub fn decodeDer(decoder: *der.Decoder) !BitString { + const ele = try decoder.element(asn1_tag.toExpected()); + const bytes = decoder.view(ele); + + if (bytes.len < 1) return error.InvalidBitString; + const padding = bytes[0]; + if (padding >= 8) return error.InvalidBitString; + const right_padding: u3 = @intCast(padding); + + // DER requires that unused bits be zero. + if (@ctz(bytes[bytes.len - 1]) < right_padding) return error.InvalidBitString; + + return BitString{ .bytes = bytes[1..], .right_padding = right_padding }; + } + + pub fn encodeDer(self: BitString, encoder: *der.Encoder) !void { + try encoder.writer().writeAll(self.bytes); + try encoder.writer().writeByte(self.right_padding); + try encoder.length(self.bytes.len + 1); + try encoder.tag(asn1_tag); + } +}; + +pub fn Opaque(comptime tag: Tag) type { + return struct { + bytes: []const u8, + + pub fn decodeDer(decoder: *der.Decoder) !@This() { + const ele = try decoder.element(tag.toExpected()); + if (tag.constructed) decoder.index = ele.slice.end; + return .{ .bytes = decoder.view(ele) }; + } + + pub fn encodeDer(self: @This(), encoder: *der.Encoder) !void { + try encoder.tagBytes(tag, self.bytes); + } + }; +} + +/// Use sparingly. +pub const Any = struct { + tag: Tag, + bytes: []const u8, + + pub fn decodeDer(decoder: *der.Decoder) !@This() { + const ele = try decoder.element(ExpectedTag{}); + return .{ .tag = ele.tag, .bytes = decoder.view(ele) }; + } + + pub fn encodeDer(self: @This(), encoder: *der.Encoder) !void { + try encoder.tagBytes(self.tag, self.bytes); + } +}; + +test { + _ = der; + _ = Oid; + _ = @import("asn1/test.zig"); +} diff --git a/lib/std/crypto/asn1/Oid.zig b/lib/std/crypto/asn1/Oid.zig new file mode 100644 index 0000000000..897b4050bf --- /dev/null +++ b/lib/std/crypto/asn1/Oid.zig @@ -0,0 +1,210 @@ +//! Globally unique hierarchical identifier made of a sequence of integers. +//! +//! Commonly used to identify standards, algorithms, certificate extensions, +//! organizations, or policy documents. +encoded: []const u8, + +pub const InitError = std.fmt.ParseIntError || error{MissingPrefix} || std.io.FixedBufferStream(u8).WriteError; + +pub fn fromDot(dot_notation: []const u8, out: []u8) InitError!Oid { + var split = std.mem.splitScalar(u8, dot_notation, '.'); + const first_str = split.next() orelse return error.MissingPrefix; + const second_str = split.next() orelse return error.MissingPrefix; + + const first = try std.fmt.parseInt(u8, first_str, 10); + const second = try std.fmt.parseInt(u8, second_str, 10); + + var stream = std.io.fixedBufferStream(out); + var writer = stream.writer(); + + try writer.writeByte(first * 40 + second); + + var i: usize = 1; + while (split.next()) |s| { + var parsed = try std.fmt.parseUnsigned(Arc, s, 10); + const n_bytes = if (parsed == 0) 0 else std.math.log(Arc, encoding_base, parsed); + + for (0..n_bytes) |j| { + const place = std.math.pow(Arc, encoding_base, n_bytes - @as(Arc, @intCast(j))); + const digit: u8 = @intCast(@divFloor(parsed, place)); + + try writer.writeByte(digit | 0x80); + parsed -= digit * place; + + i += 1; + } + try writer.writeByte(@intCast(parsed)); + i += 1; + } + + return .{ .encoded = stream.getWritten() }; +} + +test fromDot { + var buf: [256]u8 = undefined; + for (test_cases) |t| { + const actual = try fromDot(t.dot_notation, &buf); + try std.testing.expectEqualSlices(u8, t.encoded, actual.encoded); + } +} + +pub fn toDot(self: Oid, writer: anytype) @TypeOf(writer).Error!void { + const encoded = self.encoded; + const first = @divTrunc(encoded[0], 40); + const second = encoded[0] - first * 40; + try writer.print("{d}.{d}", .{ first, second }); + + var i: usize = 1; + while (i != encoded.len) { + const n_bytes: usize = brk: { + var res: usize = 1; + var j: usize = i; + while (encoded[j] & 0x80 != 0) { + res += 1; + j += 1; + } + break :brk res; + }; + + var n: usize = 0; + for (0..n_bytes) |j| { + const place = std.math.pow(usize, encoding_base, n_bytes - j - 1); + n += place * (encoded[i] & 0b01111111); + i += 1; + } + try writer.print(".{d}", .{n}); + } +} + +test toDot { + var buf: [256]u8 = undefined; + + for (test_cases) |t| { + var stream = std.io.fixedBufferStream(&buf); + try toDot(Oid{ .encoded = t.encoded }, stream.writer()); + try std.testing.expectEqualStrings(t.dot_notation, stream.getWritten()); + } +} + +const TestCase = struct { + encoded: []const u8, + dot_notation: []const u8, + + pub fn init(comptime hex: []const u8, dot_notation: []const u8) TestCase { + return .{ .encoded = &hexToBytes(hex), .dot_notation = dot_notation }; + } +}; + +const test_cases = [_]TestCase{ + // https://learn.microsoft.com/en-us/windows/win32/seccertenroll/about-object-identifier + TestCase.init("2b0601040182371514", "1.3.6.1.4.1.311.21.20"), + // https://luca.ntop.org/Teaching/Appunti/asn1.html + TestCase.init("2a864886f70d", "1.2.840.113549"), + // https://www.sysadmins.lv/blog-en/how-to-encode-object-identifier-to-an-asn1-der-encoded-string.aspx + TestCase.init("2a868d20", "1.2.100000"), + TestCase.init("2a864886f70d01010b", "1.2.840.113549.1.1.11"), + TestCase.init("2b6570", "1.3.101.112"), +}; + +pub const asn1_tag = asn1.Tag.init(.oid, false, .universal); + +pub fn decodeDer(decoder: *der.Decoder) !Oid { + const ele = try decoder.element(asn1_tag.toExpected()); + return Oid{ .encoded = decoder.view(ele) }; +} + +pub fn encodeDer(self: Oid, encoder: *der.Encoder) !void { + try encoder.tagBytes(asn1_tag, self.encoded); +} + +fn encodedLen(dot_notation: []const u8) usize { + var buf: [256]u8 = undefined; + const oid = fromDot(dot_notation, &buf) catch unreachable; + return oid.encoded.len; +} + +/// Returns encoded bytes of OID. +fn encodeComptime(comptime dot_notation: []const u8) [encodedLen(dot_notation)]u8 { + @setEvalBranchQuota(4000); + comptime var buf: [256]u8 = undefined; + const oid = comptime fromDot(dot_notation, &buf) catch unreachable; + return oid.encoded[0..oid.encoded.len].*; +} + +test encodeComptime { + try std.testing.expectEqual( + hexToBytes("2b0601040182371514"), + comptime encodeComptime("1.3.6.1.4.1.311.21.20"), + ); +} + +pub fn fromDotComptime(comptime dot_notation: []const u8) Oid { + const tmp = comptime encodeComptime(dot_notation); + return Oid{ .encoded = &tmp }; +} + +/// Maps of: +/// - Oid -> enum +/// - Enum -> oid +pub fn StaticMap(comptime Enum: type) type { + const enum_info = @typeInfo(Enum).Enum; + const EnumToOid = std.EnumArray(Enum, []const u8); + const ReturnType = struct { + oid_to_enum: std.StaticStringMap(Enum), + enum_to_oid: EnumToOid, + + pub fn oidToEnum(self: @This(), encoded: []const u8) ?Enum { + return self.oid_to_enum.get(encoded); + } + + pub fn enumToOid(self: @This(), value: Enum) Oid { + const bytes = self.enum_to_oid.get(value); + return .{ .encoded = bytes }; + } + }; + + return struct { + pub fn initComptime(comptime key_pairs: anytype) ReturnType { + const struct_info = @typeInfo(@TypeOf(key_pairs)).Struct; + const error_msg = "Each field of '" ++ @typeName(Enum) ++ "' must map to exactly one OID"; + if (!enum_info.is_exhaustive or enum_info.fields.len != struct_info.fields.len) { + @compileError(error_msg); + } + + comptime var enum_to_oid = EnumToOid.initUndefined(); + + const KeyPair = struct { []const u8, Enum }; + comptime var static_key_pairs: [enum_info.fields.len]KeyPair = undefined; + + comptime for (enum_info.fields, 0..) |f, i| { + if (!@hasField(@TypeOf(key_pairs), f.name)) { + @compileError("Field '" ++ f.name ++ "' missing Oid.StaticMap entry"); + } + const encoded = &encodeComptime(@field(key_pairs, f.name)); + const tag: Enum = @enumFromInt(f.value); + static_key_pairs[i] = .{ encoded, tag }; + enum_to_oid.set(tag, encoded); + }; + + const oid_to_enum = std.StaticStringMap(Enum).initComptime(static_key_pairs); + if (oid_to_enum.values().len != enum_info.fields.len) @compileError(error_msg); + + return ReturnType{ .oid_to_enum = oid_to_enum, .enum_to_oid = enum_to_oid }; + } + }; +} + +/// Strictly for testing. +fn hexToBytes(comptime hex: []const u8) [hex.len / 2]u8 { + var res: [hex.len / 2]u8 = undefined; + _ = std.fmt.hexToBytes(&res, hex) catch unreachable; + return res; +} + +const std = @import("std"); +const Oid = @This(); +const Arc = u32; +const encoding_base = 128; +const Allocator = std.mem.Allocator; +const der = @import("der.zig"); +const asn1 = @import("../asn1.zig"); diff --git a/lib/std/crypto/asn1/der.zig b/lib/std/crypto/asn1/der.zig new file mode 100644 index 0000000000..4395f9f3b6 --- /dev/null +++ b/lib/std/crypto/asn1/der.zig @@ -0,0 +1,55 @@ +//! Distinguised Encoding Rules as defined in X.690 and X.691. +//! +//! Subset of Basic Encoding Rules (BER) which eliminates flexibility in +//! an effort to acheive normality. Used in PKI. +const std = @import("std"); +const asn1 = @import("../asn1.zig"); + +pub const Decoder = @import("der/Decoder.zig"); +pub const Encoder = @import("der/Encoder.zig"); + +pub fn decode(comptime T: type, encoded: []const u8) !T { + var decoder = Decoder{ .bytes = encoded }; + const res = try decoder.any(T); + std.debug.assert(decoder.index == encoded.len); + return res; +} + +/// Caller owns returned memory. +pub fn encode(allocator: std.mem.Allocator, value: anytype) ![]u8 { + var encoder = Encoder.init(allocator); + defer encoder.deinit(); + try encoder.any(value); + return try encoder.buffer.toOwnedSlice(); +} + +test encode { + // https://lapo.it/asn1js/#MAgGAyoDBAIBBA + const Value = struct { a: asn1.Oid, b: i32 }; + const test_case = .{ + .value = Value{ .a = asn1.Oid.fromDotComptime("1.2.3.4"), .b = 4 }, + .encoded = &[_]u8{ 0x30, 0x08, 0x06, 0x03, 0x2A, 0x03, 0x04, 0x02, 0x01, 0x04 }, + }; + const allocator = std.testing.allocator; + const actual = try encode(allocator, test_case.value); + defer allocator.free(actual); + + try std.testing.expectEqualSlices(u8, test_case.encoded, actual); +} + +test decode { + // https://lapo.it/asn1js/#MAgGAyoDBAIBBA + const Value = struct { a: asn1.Oid, b: i32 }; + const test_case = .{ + .value = Value{ .a = asn1.Oid.fromDotComptime("1.2.3.4"), .b = 4 }, + .encoded = &[_]u8{ 0x30, 0x08, 0x06, 0x03, 0x2A, 0x03, 0x04, 0x02, 0x01, 0x04 }, + }; + const decoded = try decode(Value, test_case.encoded); + + try std.testing.expectEqualDeep(test_case.value, decoded); +} + +test { + _ = Decoder; + _ = Encoder; +} diff --git a/lib/std/crypto/asn1/der/ArrayListReverse.zig b/lib/std/crypto/asn1/der/ArrayListReverse.zig new file mode 100644 index 0000000000..f580c54546 --- /dev/null +++ b/lib/std/crypto/asn1/der/ArrayListReverse.zig @@ -0,0 +1,97 @@ +//! An ArrayList that grows backwards. Counts nested prefix length fields +//! in O(n) instead of O(n^depth) at the cost of extra buffering. +//! +//! Laid out in memory like: +//! capacity |--------------------------| +//! data |-------------| +data: []u8, +capacity: usize, +allocator: Allocator, + +const ArrayListReverse = @This(); +const Error = Allocator.Error; + +pub fn init(allocator: Allocator) ArrayListReverse { + return .{ .data = &.{}, .capacity = 0, .allocator = allocator }; +} + +pub fn deinit(self: *ArrayListReverse) void { + self.allocator.free(self.allocatedSlice()); +} + +pub fn ensureCapacity(self: *ArrayListReverse, new_capacity: usize) Error!void { + if (self.capacity >= new_capacity) return; + + const old_memory = self.allocatedSlice(); + // Just make a new allocation to not worry about aliasing. + const new_memory = try self.allocator.alloc(u8, new_capacity); + @memcpy(new_memory[new_capacity - self.data.len ..], self.data); + self.allocator.free(old_memory); + self.data.ptr = new_memory.ptr + new_capacity - self.data.len; + self.capacity = new_memory.len; +} + +pub fn prependSlice(self: *ArrayListReverse, data: []const u8) Error!void { + try self.ensureCapacity(self.data.len + data.len); + const old_len = self.data.len; + const new_len = old_len + data.len; + assert(new_len <= self.capacity); + self.data.len = new_len; + + const end = self.data.ptr; + const begin = end - data.len; + const slice = begin[0..data.len]; + @memcpy(slice, data); + self.data.ptr = begin; +} + +pub const Writer = std.io.Writer(*ArrayListReverse, Error, prependSliceSize); +/// Warning: This writer writes backwards. `fn print` will NOT work as expected. +pub fn writer(self: *ArrayListReverse) Writer { + return .{ .context = self }; +} + +fn prependSliceSize(self: *ArrayListReverse, data: []const u8) Error!usize { + try self.prependSlice(data); + return data.len; +} + +fn allocatedSlice(self: *ArrayListReverse) []u8 { + return (self.data.ptr + self.data.len - self.capacity)[0..self.capacity]; +} + +/// Invalidates all element pointers. +pub fn clearAndFree(self: *ArrayListReverse) void { + self.allocator.free(self.allocatedSlice()); + self.data.len = 0; + self.capacity = 0; +} + +/// The caller owns the returned memory. +/// Capacity is cleared, making deinit() safe but unnecessary to call. +pub fn toOwnedSlice(self: *ArrayListReverse) Error![]u8 { + const new_memory = try self.allocator.alloc(u8, self.data.len); + @memcpy(new_memory, self.data); + @memset(self.data, undefined); + self.clearAndFree(); + return new_memory; +} + +const std = @import("std"); +const Allocator = std.mem.Allocator; +const assert = std.debug.assert; +const testing = std.testing; + +test ArrayListReverse { + var b = ArrayListReverse.init(testing.allocator); + defer b.deinit(); + const data: []const u8 = &.{ 4, 5, 6 }; + try b.prependSlice(data); + try testing.expectEqual(data.len, b.data.len); + try testing.expectEqualSlices(u8, data, b.data); + + const data2: []const u8 = &.{ 1, 2, 3 }; + try b.prependSlice(data2); + try testing.expectEqual(data.len + data2.len, b.data.len); + try testing.expectEqualSlices(u8, data2 ++ data, b.data); +} diff --git a/lib/std/crypto/asn1/der/Decoder.zig b/lib/std/crypto/asn1/der/Decoder.zig new file mode 100644 index 0000000000..2eedbee957 --- /dev/null +++ b/lib/std/crypto/asn1/der/Decoder.zig @@ -0,0 +1,172 @@ +//! A secure DER parser that: +//! - Prefers calling `fn decodeDer(self: @This(), decoder: *der.Decoder)` +//! - Does NOT allocate. If you wish to parse lists you can do so lazily +//! with an opaque type. +//! - Does NOT read memory outside `bytes`. +//! - Does NOT return elements with slices outside `bytes`. +//! - Errors on values that do NOT follow DER rules: +//! - Lengths that could be represented in a shorter form. +//! - Booleans that are not 0xff or 0x00. +bytes: []const u8, +index: Index = 0, +/// The field tag of the most recently visited field. +/// This is needed because we might visit an implicitly tagged container with a `fn decodeDer`. +field_tag: ?FieldTag = null, + +/// Expect a value. +pub fn any(self: *Decoder, comptime T: type) !T { + if (std.meta.hasFn(T, "decodeDer")) return try T.decodeDer(self); + + const tag = Tag.fromZig(T).toExpected(); + switch (@typeInfo(T)) { + .Struct => { + const ele = try self.element(tag); + defer self.index = ele.slice.end; // don't force parsing all fields + + var res: T = undefined; + + inline for (std.meta.fields(T)) |f| { + self.field_tag = FieldTag.fromContainer(T, f.name); + + if (self.field_tag) |ft| { + if (ft.explicit) { + const seq = try self.element(ft.toTag().toExpected()); + self.index = seq.slice.start; + self.field_tag = null; + } + } + + @field(res, f.name) = self.any(f.type) catch |err| brk: { + if (f.default_value) |d| { + break :brk @as(*const f.type, @alignCast(@ptrCast(d))).*; + } + return err; + }; + // DER encodes null values by skipping them. + if (@typeInfo(f.type) == .Optional and @field(res, f.name) == null) { + if (f.default_value) |d| { + @field(res, f.name) = @as(*const f.type, @alignCast(@ptrCast(d))).*; + } + } + } + + return res; + }, + .Bool => { + const ele = try self.element(tag); + const bytes = self.view(ele); + if (bytes.len != 1) return error.InvalidBool; + + return switch (bytes[0]) { + 0x00 => false, + 0xff => true, + else => error.InvalidBool, + }; + }, + .Int => { + const ele = try self.element(tag); + const bytes = self.view(ele); + return try int(T, bytes); + }, + .Enum => |e| { + const ele = try self.element(tag); + const bytes = self.view(ele); + if (@hasDecl(T, "oids")) { + return T.oids.oidToEnum(bytes) orelse return error.UnknownOid; + } + return @enumFromInt(try int(e.tag_type, bytes)); + }, + .Optional => |o| return self.any(o.child) catch return null, + else => @compileError("cannot decode type " ++ @typeName(T)), + } +} + +//// Expect a sequence. +pub fn sequence(self: *Decoder) !Element { + return try self.element(ExpectedTag.init(.sequence, true, .universal)); +} + +//// Expect an element. +pub fn element( + self: *Decoder, + expected: ExpectedTag, +) (error{ EndOfStream, UnexpectedElement } || Element.DecodeError)!Element { + if (self.index >= self.bytes.len) return error.EndOfStream; + + const res = try Element.decode(self.bytes, self.index); + var e = expected; + if (self.field_tag) |ft| { + e.number = @enumFromInt(ft.number); + e.class = ft.class; + } + if (!e.match(res.tag)) { + return error.UnexpectedElement; + } + + self.index = if (res.tag.constructed) res.slice.start else res.slice.end; + return res; +} + +/// View of element bytes. +pub fn view(self: Decoder, elem: Element) []const u8 { + return elem.slice.view(self.bytes); +} + +fn int(comptime T: type, value: []const u8) error{ NonCanonical, LargeValue }!T { + if (@typeInfo(T).Int.bits % 8 != 0) @compileError("T must be byte aligned"); + + var bytes = value; + if (bytes.len >= 2) { + if (bytes[0] == 0) { + if (@clz(bytes[1]) > 0) return error.NonCanonical; + bytes.ptr += 1; + } + if (bytes[0] == 0xff and @clz(bytes[1]) == 0) return error.NonCanonical; + } + + if (bytes.len > @sizeOf(T)) return error.LargeValue; + if (@sizeOf(T) == 1) return @bitCast(bytes[0]); + + return std.mem.readVarInt(T, bytes, .big); +} + +test int { + try expectEqual(@as(u8, 1), try int(u8, &[_]u8{1})); + try expectError(error.NonCanonical, int(u8, &[_]u8{ 0, 1 })); + try expectError(error.NonCanonical, int(u8, &[_]u8{ 0xff, 0xff })); + + const big = [_]u8{ 0xef, 0xff }; + try expectError(error.LargeValue, int(u8, &big)); + try expectEqual(0xefff, int(u16, &big)); +} + +test Decoder { + var parser = Decoder{ .bytes = @embedFile("./testdata/id_ecc.pub.der") }; + const seq = try parser.sequence(); + + { + const seq2 = try parser.sequence(); + _ = try parser.element(ExpectedTag.init(.oid, false, .universal)); + _ = try parser.element(ExpectedTag.init(.oid, false, .universal)); + + try std.testing.expectEqual(parser.index, seq2.slice.end); + } + _ = try parser.element(ExpectedTag.init(.bitstring, false, .universal)); + + try std.testing.expectEqual(parser.index, seq.slice.end); + try std.testing.expectEqual(parser.index, parser.bytes.len); +} + +const std = @import("std"); +const builtin = @import("builtin"); +const asn1 = @import("../../asn1.zig"); +const Oid = @import("../Oid.zig"); + +const expectEqual = std.testing.expectEqual; +const expectError = std.testing.expectError; +const Decoder = @This(); +const Index = asn1.Index; +const Tag = asn1.Tag; +const FieldTag = asn1.FieldTag; +const ExpectedTag = asn1.ExpectedTag; +const Element = asn1.Element; diff --git a/lib/std/crypto/asn1/der/Encoder.zig b/lib/std/crypto/asn1/der/Encoder.zig new file mode 100644 index 0000000000..939a1a5aa7 --- /dev/null +++ b/lib/std/crypto/asn1/der/Encoder.zig @@ -0,0 +1,166 @@ +//! A buffered DER encoder. +//! +//! Prefers calling container's `fn encodeDer(self: @This(), encoder: *der.Encoder)`. +//! That function should encode values, lengths, then tags. +buffer: ArrayListReverse, +/// The field tag set by a parent container. +/// This is needed because we might visit an implicitly tagged container with a `fn encodeDer`. +field_tag: ?FieldTag = null, + +pub fn init(allocator: std.mem.Allocator) Encoder { + return Encoder{ .buffer = ArrayListReverse.init(allocator) }; +} + +pub fn deinit(self: *Encoder) void { + self.buffer.deinit(); +} + +/// Encode any value. +pub fn any(self: *Encoder, val: anytype) !void { + const T = @TypeOf(val); + try self.anyTag(Tag.fromZig(T), val); +} + +fn anyTag(self: *Encoder, tag_: Tag, val: anytype) !void { + const T = @TypeOf(val); + if (std.meta.hasFn(T, "encodeDer")) return try val.encodeDer(self); + const start = self.buffer.data.len; + const merged_tag = self.mergedTag(tag_); + + switch (@typeInfo(T)) { + .Struct => |info| { + inline for (0..info.fields.len) |i| { + const f = info.fields[info.fields.len - i - 1]; + const field_val = @field(val, f.name); + const field_tag = FieldTag.fromContainer(T, f.name); + + // > The encoding of a set value or sequence value shall not include an encoding for any + // > component value which is equal to its default value. + const is_default = if (f.is_comptime) false else if (f.default_value) |v| brk: { + const default_val: *const f.type = @alignCast(@ptrCast(v)); + break :brk std.mem.eql(u8, std.mem.asBytes(default_val), std.mem.asBytes(&field_val)); + } else false; + + if (!is_default) { + const start2 = self.buffer.data.len; + self.field_tag = field_tag; + // will merge with self.field_tag. + // may mutate self.field_tag. + try self.anyTag(Tag.fromZig(f.type), field_val); + if (field_tag) |ft| { + if (ft.explicit) { + try self.length(self.buffer.data.len - start2); + try self.tag(ft.toTag()); + self.field_tag = null; + } + } + } + } + }, + .Bool => try self.buffer.prependSlice(&[_]u8{if (val) 0xff else 0}), + .Int => try self.int(T, val), + .Enum => |e| { + if (@hasDecl(T, "oids")) { + return self.any(T.oids.enumToOid(val)); + } else { + try self.int(e.tag_type, @intFromEnum(val)); + } + }, + .Optional => if (val) |v| return try self.anyTag(tag_, v), + .Null => {}, + else => @compileError("cannot encode type " ++ @typeName(T)), + } + + try self.length(self.buffer.data.len - start); + try self.tag(merged_tag); +} + +/// Encode a tag. +pub fn tag(self: *Encoder, tag_: Tag) !void { + const t = self.mergedTag(tag_); + try t.encode(self.writer()); +} + +fn mergedTag(self: *Encoder, tag_: Tag) Tag { + var res = tag_; + if (self.field_tag) |ft| { + if (!ft.explicit) { + res.number = @enumFromInt(ft.number); + res.class = ft.class; + } + } + return res; +} + +/// Encode a length. +pub fn length(self: *Encoder, len: usize) !void { + const writer_ = self.writer(); + if (len < 128) { + try writer_.writeInt(u8, @intCast(len), .big); + return; + } + inline for ([_]type{ u8, u16, u32 }) |T| { + if (len < std.math.maxInt(T)) { + try writer_.writeInt(T, @intCast(len), .big); + try writer_.writeInt(u8, @sizeOf(T) | 0x80, .big); + return; + } + } + return error.InvalidLength; +} + +/// Encode a tag and length-prefixed bytes. +pub fn tagBytes(self: *Encoder, tag_: Tag, bytes: []const u8) !void { + try self.buffer.prependSlice(bytes); + try self.length(bytes.len); + try self.tag(tag_); +} + +/// Warning: This writer writes backwards. `fn print` will NOT work as expected. +pub fn writer(self: *Encoder) ArrayListReverse.Writer { + return self.buffer.writer(); +} + +fn int(self: *Encoder, comptime T: type, value: T) !void { + const big = std.mem.nativeTo(T, value, .big); + const big_bytes = std.mem.asBytes(&big); + + const bits_needed = @bitSizeOf(T) - @clz(value); + const needs_padding: u1 = if (value == 0) + 1 + else if (bits_needed > 8) brk: { + const RightShift = std.meta.Int(.unsigned, @bitSizeOf(@TypeOf(bits_needed)) - 1); + const right_shift: RightShift = @intCast(bits_needed - 9); + break :brk if (value >> right_shift == 0x1ff) 1 else 0; + } else 0; + const bytes_needed = try std.math.divCeil(usize, bits_needed, 8) + needs_padding; + + const writer_ = self.writer(); + for (0..bytes_needed - needs_padding) |i| try writer_.writeByte(big_bytes[big_bytes.len - i - 1]); + if (needs_padding == 1) try writer_.writeByte(0); +} + +test int { + const allocator = std.testing.allocator; + var encoder = Encoder.init(allocator); + defer encoder.deinit(); + + try encoder.int(u8, 0); + try std.testing.expectEqualSlices(u8, &[_]u8{0}, encoder.buffer.data); + + encoder.buffer.clearAndFree(); + try encoder.int(u16, 0x00ff); + try std.testing.expectEqualSlices(u8, &[_]u8{0xff}, encoder.buffer.data); + + encoder.buffer.clearAndFree(); + try encoder.int(u32, 0xffff); + try std.testing.expectEqualSlices(u8, &[_]u8{ 0, 0xff, 0xff }, encoder.buffer.data); +} + +const std = @import("std"); +const Oid = @import("../Oid.zig"); +const asn1 = @import("../../asn1.zig"); +const ArrayListReverse = @import("./ArrayListReverse.zig"); +const Tag = asn1.Tag; +const FieldTag = asn1.FieldTag; +const Encoder = @This(); diff --git a/lib/std/crypto/asn1/der/testdata/all_types.der b/lib/std/crypto/asn1/der/testdata/all_types.der new file mode 100644 index 0000000000..a4f784938b Binary files /dev/null and b/lib/std/crypto/asn1/der/testdata/all_types.der differ diff --git a/lib/std/crypto/asn1/der/testdata/id_ecc.pub.der b/lib/std/crypto/asn1/der/testdata/id_ecc.pub.der new file mode 100644 index 0000000000..3964a8f0fc Binary files /dev/null and b/lib/std/crypto/asn1/der/testdata/id_ecc.pub.der differ diff --git a/lib/std/crypto/asn1/test.zig b/lib/std/crypto/asn1/test.zig new file mode 100644 index 0000000000..261f5a4310 --- /dev/null +++ b/lib/std/crypto/asn1/test.zig @@ -0,0 +1,80 @@ +const std = @import("std"); +const asn1 = @import("../asn1.zig"); + +const der = asn1.der; +const Tag = asn1.Tag; +const FieldTag = asn1.FieldTag; + +/// An example that uses all ASN1 types and available implementation features. +const AllTypes = struct { + a: u8 = 0, + b: asn1.BitString, + c: C, + d: asn1.Opaque(Tag.universal(.string_utf8, false)), + e: asn1.Opaque(Tag.universal(.octetstring, false)), + f: ?u16, + g: ?Nested, + h: asn1.Any, + + pub const asn1_tags = .{ + .a = FieldTag.explicit(0, .context_specific), + .b = FieldTag.explicit(1, .context_specific), + .c = FieldTag.implicit(2, .context_specific), + .g = FieldTag.implicit(3, .context_specific), + }; + + const C = enum { + a, + b, + + pub const oids = asn1.Oid.StaticMap(@This()).initComptime(.{ + .a = "1.2.3.4", + .b = "1.2.3.5", + }); + }; + + const Nested = struct { + inner: Asn1T, + sum: i16, + + const Asn1T = struct { a: u8, b: i16 }; + + pub fn decodeDer(decoder: *der.Decoder) !Nested { + const inner = try decoder.any(Asn1T); + return Nested{ .inner = inner, .sum = inner.a + inner.b }; + } + + pub fn encodeDer(self: Nested, encoder: *der.Encoder) !void { + try encoder.any(self.inner); + } + }; +}; + +test AllTypes { + const expected = AllTypes{ + .a = 2, + .b = asn1.BitString{ .bytes = &[_]u8{ 0x04, 0xa0 } }, + .c = .a, + .d = .{ .bytes = "asdf" }, + .e = .{ .bytes = "fdsa" }, + .f = (1 << 8) + 1, + .g = .{ .inner = .{ .a = 4, .b = 5 }, .sum = 9 }, + .h = .{ .tag = Tag.init(.string_ia5, false, .universal), .bytes = "asdf" }, + }; + // https://lapo.it/asn1js/#MC-gAwIBAqEFAwMABKCCAyoDBAwEYXNkZgQEZmRzYQICAQGjBgIBBAIBBRYEYXNkZg + const path = "./der/testdata/all_types.der"; + const encoded = @embedFile(path); + const actual = try asn1.der.decode(AllTypes, encoded); + try std.testing.expectEqualDeep(expected, actual); + + const allocator = std.testing.allocator; + const buf = try asn1.der.encode(allocator, expected); + defer allocator.free(buf); + try std.testing.expectEqualSlices(u8, encoded, buf); + + // Use this to update test file. + // const dir = try std.fs.cwd().openDir("lib/std/crypto/asn1", .{}); + // var file = try dir.createFile(path, .{}); + // defer file.close(); + // try file.writeAll(buf); +}