diff --git a/lib/std/crypto/utils.zig b/lib/std/crypto/utils.zig index 08271ac9f4..d5d5d9ade3 100644 --- a/lib/std/crypto/utils.zig +++ b/lib/std/crypto/utils.zig @@ -2,6 +2,9 @@ const std = @import("../std.zig"); const mem = std.mem; const testing = std.testing; +const Endian = std.builtin.Endian; +const Order = std.math.Order; + /// Compares two arrays in constant time (for a given length) and returns whether they are equal. /// This function was designed to compare short cryptographic secrets (MACs, signatures). /// For all other applications, use mem.eql() instead. @@ -38,6 +41,48 @@ pub fn timingSafeEql(comptime T: type, a: T, b: T) bool { } } +/// Compare two integers serialized as arrays of the same size, in constant time. +/// Returns .lt if ab and .eq if a=b +pub fn timingSafeCompare(comptime T: type, a: T, b: T, endian: Endian) Order { + switch (@typeInfo(T)) { + .Array => |info| { + const C = info.child; + const bits = switch (@typeInfo(C)) { + .Int => |cinfo| if (cinfo.signedness != .unsigned) @compileError("Elements to be compared must be unsigned") else cinfo.bits, + else => @compileError("Elements to be compared must be integers"), + }; + comptime const Cext = std.meta.Int(.unsigned, bits + 1); + var gt: C = 0; + var eq: C = 1; + if (endian == .Little) { + var i = a.len; + while (i != 0) { + i -= 1; + const x1 = a[i]; + const x2 = b[i]; + gt |= @truncate(C, (@as(Cext, x2) -% @as(Cext, x1)) >> bits) & eq; + eq &= @truncate(C, (@as(Cext, (x2 ^ x1)) -% 1) >> bits); + } + } else { + for (a) |x1, i| { + const x2 = b[i]; + gt |= @truncate(C, (@as(Cext, x2) -% @as(Cext, x1)) >> bits) & eq; + eq &= @truncate(C, (@as(Cext, (x2 ^ x1)) -% 1) >> bits); + } + } + if (gt != 0) { + return Order.gt; + } else if (eq != 0) { + return Order.eq; + } + return Order.lt; + }, + else => { + @compileError("Only arrays can be compared"); + }, + } +} + /// Sets a slice to zeroes. /// Prevents the store from being optimized out. pub fn secureZero(comptime T: type, s: []T) void { @@ -70,6 +115,19 @@ test "crypto.utils.timingSafeEql (vectors)" { testing.expect(timingSafeEql(std.meta.Vector(100, u8), v1, v3)); } +test "crypto.utils.timingSafeCompare" { + var a = [_]u8{10} ** 32; + var b = [_]u8{10} ** 32; + testing.expectEqual(timingSafeCompare([32]u8, a, b, .Big), .eq); + testing.expectEqual(timingSafeCompare([32]u8, a, b, .Little), .eq); + a[31] = 1; + testing.expectEqual(timingSafeCompare([32]u8, a, b, .Big), .lt); + testing.expectEqual(timingSafeCompare([32]u8, a, b, .Little), .lt); + a[0] = 20; + testing.expectEqual(timingSafeCompare([32]u8, a, b, .Big), .gt); + testing.expectEqual(timingSafeCompare([32]u8, a, b, .Little), .lt); +} + test "crypto.utils.secureZero" { var a = [_]u8{0xfe} ** 8; var b = [_]u8{0xfe} ** 8;