diff --git a/lib/std/crypto/utils.zig b/lib/std/crypto/utils.zig index d5d5d9ade3..ca86601bf9 100644 --- a/lib/std/crypto/utils.zig +++ b/lib/std/crypto/utils.zig @@ -1,4 +1,5 @@ const std = @import("../std.zig"); +const debug = std.debug; const mem = std.mem; const testing = std.testing; @@ -43,44 +44,37 @@ pub fn timingSafeEql(comptime T: type, a: T, b: T) bool { /// Compare two integers serialized as arrays of the same size, in constant time. /// Returns .lt if ab and .eq if a=b -pub fn timingSafeCompare(comptime T: type, a: T, b: T, endian: Endian) Order { - switch (@typeInfo(T)) { - .Array => |info| { - const C = info.child; - const bits = switch (@typeInfo(C)) { - .Int => |cinfo| if (cinfo.signedness != .unsigned) @compileError("Elements to be compared must be unsigned") else cinfo.bits, - else => @compileError("Elements to be compared must be integers"), - }; - comptime const Cext = std.meta.Int(.unsigned, bits + 1); - var gt: C = 0; - var eq: C = 1; - if (endian == .Little) { - var i = a.len; - while (i != 0) { - i -= 1; - const x1 = a[i]; - const x2 = b[i]; - gt |= @truncate(C, (@as(Cext, x2) -% @as(Cext, x1)) >> bits) & eq; - eq &= @truncate(C, (@as(Cext, (x2 ^ x1)) -% 1) >> bits); - } - } else { - for (a) |x1, i| { - const x2 = b[i]; - gt |= @truncate(C, (@as(Cext, x2) -% @as(Cext, x1)) >> bits) & eq; - eq &= @truncate(C, (@as(Cext, (x2 ^ x1)) -% 1) >> bits); - } - } - if (gt != 0) { - return Order.gt; - } else if (eq != 0) { - return Order.eq; - } - return Order.lt; - }, - else => { - @compileError("Only arrays can be compared"); - }, +pub fn timingSafeCompare(comptime T: type, a: []const T, b: []const T, endian: Endian) Order { + debug.assert(a.len == b.len); + const bits = switch (@typeInfo(T)) { + .Int => |cinfo| if (cinfo.signedness != .unsigned) @compileError("Elements to be compared must be unsigned") else cinfo.bits, + else => @compileError("Elements to be compared must be integers"), + }; + comptime const Cext = std.meta.Int(.unsigned, bits + 1); + var gt: T = 0; + var eq: T = 1; + if (endian == .Little) { + var i = a.len; + while (i != 0) { + i -= 1; + const x1 = a[i]; + const x2 = b[i]; + gt |= @truncate(T, (@as(Cext, x2) -% @as(Cext, x1)) >> bits) & eq; + eq &= @truncate(T, (@as(Cext, (x2 ^ x1)) -% 1) >> bits); + } + } else { + for (a) |x1, i| { + const x2 = b[i]; + gt |= @truncate(T, (@as(Cext, x2) -% @as(Cext, x1)) >> bits) & eq; + eq &= @truncate(T, (@as(Cext, (x2 ^ x1)) -% 1) >> bits); + } } + if (gt != 0) { + return Order.gt; + } else if (eq != 0) { + return Order.eq; + } + return Order.lt; } /// Sets a slice to zeroes. @@ -118,14 +112,14 @@ test "crypto.utils.timingSafeEql (vectors)" { test "crypto.utils.timingSafeCompare" { var a = [_]u8{10} ** 32; var b = [_]u8{10} ** 32; - testing.expectEqual(timingSafeCompare([32]u8, a, b, .Big), .eq); - testing.expectEqual(timingSafeCompare([32]u8, a, b, .Little), .eq); + testing.expectEqual(timingSafeCompare(u8, &a, &b, .Big), .eq); + testing.expectEqual(timingSafeCompare(u8, &a, &b, .Little), .eq); a[31] = 1; - testing.expectEqual(timingSafeCompare([32]u8, a, b, .Big), .lt); - testing.expectEqual(timingSafeCompare([32]u8, a, b, .Little), .lt); + testing.expectEqual(timingSafeCompare(u8, &a, &b, .Big), .lt); + testing.expectEqual(timingSafeCompare(u8, &a, &b, .Little), .lt); a[0] = 20; - testing.expectEqual(timingSafeCompare([32]u8, a, b, .Big), .gt); - testing.expectEqual(timingSafeCompare([32]u8, a, b, .Little), .lt); + testing.expectEqual(timingSafeCompare(u8, &a, &b, .Big), .gt); + testing.expectEqual(timingSafeCompare(u8, &a, &b, .Little), .lt); } test "crypto.utils.secureZero" {