diff --git a/lib/std/crypto.zig b/lib/std/crypto.zig index f46e7b1022..b469620002 100644 --- a/lib/std/crypto.zig +++ b/lib/std/crypto.zig @@ -1,3 +1,5 @@ +const root = @import("root"); + /// Authenticated Encryption with Associated Data pub const aead = struct { pub const aegis = struct { @@ -183,6 +185,31 @@ pub const errors = @import("crypto/errors.zig"); pub const tls = @import("crypto/tls.zig"); pub const Certificate = @import("crypto/Certificate.zig"); +/// Global configuration of cryptographic implementations in the standard library. +pub const config = struct { + /// Side-channel mitigations. + pub const SideChannelsMitigations = enum { + /// No additional side-channel mitigations are applied. + /// This is the fastest mode. + none, + /// The `basic` mode protects against most practical attacks, provided that the + /// application implements proper defenses against brute-force attacks. + /// It offers a good balance between performance and security. + basic, + /// The `medium` mode offers increased resilience against side-channel attacks, + /// making most attacks impractical even on shared/low latency environments. + /// This is the default mode. + medium, + /// The `full` mode offers the highest level of protection against side-channel attacks. + /// Note that this doesn't cover all possible attacks (especially power analysis or + /// thread-local attacks such as cachebleed), and that the performance impact is significant. + full, + }; + + /// This is a global configuration that applies to all cryptographic implementations. 
+ pub const side_channels_mitigations: SideChannelsMitigations = if (@hasDecl(root, "side_channels_mitigations")) root.side_channels_mitigations else .medium; +}; + test { _ = aead.aegis.Aegis128L; _ = aead.aegis.Aegis256; diff --git a/lib/std/crypto/aes/soft.zig b/lib/std/crypto/aes/soft.zig index d8bd3d4ac0..4a300961c6 100644 --- a/lib/std/crypto/aes/soft.zig +++ b/lib/std/crypto/aes/soft.zig @@ -1,11 +1,11 @@ -// Based on Go stdlib implementation - const std = @import("../../std.zig"); const math = std.math; const mem = std.mem; const BlockVec = [4]u32; +const side_channels_mitigations = std.crypto.config.side_channels_mitigations; + /// A single AES block. pub const Block = struct { pub const block_length: usize = 16; @@ -15,20 +15,20 @@ pub const Block = struct { /// Convert a byte sequence into an internal representation. pub inline fn fromBytes(bytes: *const [16]u8) Block { - const s0 = mem.readIntBig(u32, bytes[0..4]); - const s1 = mem.readIntBig(u32, bytes[4..8]); - const s2 = mem.readIntBig(u32, bytes[8..12]); - const s3 = mem.readIntBig(u32, bytes[12..16]); + const s0 = mem.readIntLittle(u32, bytes[0..4]); + const s1 = mem.readIntLittle(u32, bytes[4..8]); + const s2 = mem.readIntLittle(u32, bytes[8..12]); + const s3 = mem.readIntLittle(u32, bytes[12..16]); return Block{ .repr = BlockVec{ s0, s1, s2, s3 } }; } /// Convert the internal representation of a block into a byte sequence. 
pub inline fn toBytes(block: Block) [16]u8 { var bytes: [16]u8 = undefined; - mem.writeIntBig(u32, bytes[0..4], block.repr[0]); - mem.writeIntBig(u32, bytes[4..8], block.repr[1]); - mem.writeIntBig(u32, bytes[8..12], block.repr[2]); - mem.writeIntBig(u32, bytes[12..16], block.repr[3]); + mem.writeIntLittle(u32, bytes[0..4], block.repr[0]); + mem.writeIntLittle(u32, bytes[4..8], block.repr[1]); + mem.writeIntLittle(u32, bytes[8..12], block.repr[2]); + mem.writeIntLittle(u32, bytes[12..16], block.repr[3]); return bytes; } @@ -50,32 +50,93 @@ pub const Block = struct { const s2 = block.repr[2]; const s3 = block.repr[3]; - const t0 = round_key.repr[0] ^ table_encrypt[0][@truncate(u8, s0 >> 24)] ^ table_encrypt[1][@truncate(u8, s1 >> 16)] ^ table_encrypt[2][@truncate(u8, s2 >> 8)] ^ table_encrypt[3][@truncate(u8, s3)]; - const t1 = round_key.repr[1] ^ table_encrypt[0][@truncate(u8, s1 >> 24)] ^ table_encrypt[1][@truncate(u8, s2 >> 16)] ^ table_encrypt[2][@truncate(u8, s3 >> 8)] ^ table_encrypt[3][@truncate(u8, s0)]; - const t2 = round_key.repr[2] ^ table_encrypt[0][@truncate(u8, s2 >> 24)] ^ table_encrypt[1][@truncate(u8, s3 >> 16)] ^ table_encrypt[2][@truncate(u8, s0 >> 8)] ^ table_encrypt[3][@truncate(u8, s1)]; - const t3 = round_key.repr[3] ^ table_encrypt[0][@truncate(u8, s3 >> 24)] ^ table_encrypt[1][@truncate(u8, s0 >> 16)] ^ table_encrypt[2][@truncate(u8, s1 >> 8)] ^ table_encrypt[3][@truncate(u8, s2)]; + var x: [4]u32 = undefined; + x = table_lookup(&table_encrypt, @truncate(u8, s0), @truncate(u8, s1 >> 8), @truncate(u8, s2 >> 16), @truncate(u8, s3 >> 24)); + var t0 = x[0] ^ x[1] ^ x[2] ^ x[3]; + x = table_lookup(&table_encrypt, @truncate(u8, s1), @truncate(u8, s2 >> 8), @truncate(u8, s3 >> 16), @truncate(u8, s0 >> 24)); + var t1 = x[0] ^ x[1] ^ x[2] ^ x[3]; + x = table_lookup(&table_encrypt, @truncate(u8, s2), @truncate(u8, s3 >> 8), @truncate(u8, s0 >> 16), @truncate(u8, s1 >> 24)); + var t2 = x[0] ^ x[1] ^ x[2] ^ x[3]; + x = table_lookup(&table_encrypt, 
@truncate(u8, s3), @truncate(u8, s0 >> 8), @truncate(u8, s1 >> 16), @truncate(u8, s2 >> 24)); + var t3 = x[0] ^ x[1] ^ x[2] ^ x[3]; + + t0 ^= round_key.repr[0]; + t1 ^= round_key.repr[1]; + t2 ^= round_key.repr[2]; + t3 ^= round_key.repr[3]; + + return Block{ .repr = BlockVec{ t0, t1, t2, t3 } }; + } + + /// Encrypt a block with a round key *WITHOUT ANY PROTECTION AGAINST SIDE CHANNELS* + pub inline fn encryptUnprotected(block: Block, round_key: Block) Block { + const s0 = block.repr[0]; + const s1 = block.repr[1]; + const s2 = block.repr[2]; + const s3 = block.repr[3]; + + var x: [4]u32 = undefined; + x = .{ + table_encrypt[0][@truncate(u8, s0)], + table_encrypt[1][@truncate(u8, s1 >> 8)], + table_encrypt[2][@truncate(u8, s2 >> 16)], + table_encrypt[3][@truncate(u8, s3 >> 24)], + }; + var t0 = x[0] ^ x[1] ^ x[2] ^ x[3]; + x = .{ + table_encrypt[0][@truncate(u8, s1)], + table_encrypt[1][@truncate(u8, s2 >> 8)], + table_encrypt[2][@truncate(u8, s3 >> 16)], + table_encrypt[3][@truncate(u8, s0 >> 24)], + }; + var t1 = x[0] ^ x[1] ^ x[2] ^ x[3]; + x = .{ + table_encrypt[0][@truncate(u8, s2)], + table_encrypt[1][@truncate(u8, s3 >> 8)], + table_encrypt[2][@truncate(u8, s0 >> 16)], + table_encrypt[3][@truncate(u8, s1 >> 24)], + }; + var t2 = x[0] ^ x[1] ^ x[2] ^ x[3]; + x = .{ + table_encrypt[0][@truncate(u8, s3)], + table_encrypt[1][@truncate(u8, s0 >> 8)], + table_encrypt[2][@truncate(u8, s1 >> 16)], + table_encrypt[3][@truncate(u8, s2 >> 24)], + }; + var t3 = x[0] ^ x[1] ^ x[2] ^ x[3]; + + t0 ^= round_key.repr[0]; + t1 ^= round_key.repr[1]; + t2 ^= round_key.repr[2]; + t3 ^= round_key.repr[3]; return Block{ .repr = BlockVec{ t0, t1, t2, t3 } }; } /// Encrypt a block with the last round key. 
pub inline fn encryptLast(block: Block, round_key: Block) Block { - const t0 = block.repr[0]; - const t1 = block.repr[1]; - const t2 = block.repr[2]; - const t3 = block.repr[3]; + const s0 = block.repr[0]; + const s1 = block.repr[1]; + const s2 = block.repr[2]; + const s3 = block.repr[3]; // Last round uses s-box directly and XORs to produce output. - var s0 = @as(u32, sbox_encrypt[t0 >> 24]) << 24 | @as(u32, sbox_encrypt[t1 >> 16 & 0xff]) << 16 | @as(u32, sbox_encrypt[t2 >> 8 & 0xff]) << 8 | @as(u32, sbox_encrypt[t3 & 0xff]); - var s1 = @as(u32, sbox_encrypt[t1 >> 24]) << 24 | @as(u32, sbox_encrypt[t2 >> 16 & 0xff]) << 16 | @as(u32, sbox_encrypt[t3 >> 8 & 0xff]) << 8 | @as(u32, sbox_encrypt[t0 & 0xff]); - var s2 = @as(u32, sbox_encrypt[t2 >> 24]) << 24 | @as(u32, sbox_encrypt[t3 >> 16 & 0xff]) << 16 | @as(u32, sbox_encrypt[t0 >> 8 & 0xff]) << 8 | @as(u32, sbox_encrypt[t1 & 0xff]); - var s3 = @as(u32, sbox_encrypt[t3 >> 24]) << 24 | @as(u32, sbox_encrypt[t0 >> 16 & 0xff]) << 16 | @as(u32, sbox_encrypt[t1 >> 8 & 0xff]) << 8 | @as(u32, sbox_encrypt[t2 & 0xff]); - s0 ^= round_key.repr[0]; - s1 ^= round_key.repr[1]; - s2 ^= round_key.repr[2]; - s3 ^= round_key.repr[3]; + var x: [4]u8 = undefined; + x = sbox_lookup(&sbox_encrypt, @truncate(u8, s3 >> 24), @truncate(u8, s2 >> 16), @truncate(u8, s1 >> 8), @truncate(u8, s0)); + var t0 = @as(u32, x[0]) << 24 | @as(u32, x[1]) << 16 | @as(u32, x[2]) << 8 | @as(u32, x[3]); + x = sbox_lookup(&sbox_encrypt, @truncate(u8, s0 >> 24), @truncate(u8, s3 >> 16), @truncate(u8, s2 >> 8), @truncate(u8, s1)); + var t1 = @as(u32, x[0]) << 24 | @as(u32, x[1]) << 16 | @as(u32, x[2]) << 8 | @as(u32, x[3]); + x = sbox_lookup(&sbox_encrypt, @truncate(u8, s1 >> 24), @truncate(u8, s0 >> 16), @truncate(u8, s3 >> 8), @truncate(u8, s2)); + var t2 = @as(u32, x[0]) << 24 | @as(u32, x[1]) << 16 | @as(u32, x[2]) << 8 | @as(u32, x[3]); + x = sbox_lookup(&sbox_encrypt, @truncate(u8, s2 >> 24), @truncate(u8, s1 >> 16), @truncate(u8, s0 >> 8), @truncate(u8, 
s3)); + var t3 = @as(u32, x[0]) << 24 | @as(u32, x[1]) << 16 | @as(u32, x[2]) << 8 | @as(u32, x[3]); - return Block{ .repr = BlockVec{ s0, s1, s2, s3 } }; + t0 ^= round_key.repr[0]; + t1 ^= round_key.repr[1]; + t2 ^= round_key.repr[2]; + t3 ^= round_key.repr[3]; + + return Block{ .repr = BlockVec{ t0, t1, t2, t3 } }; } /// Decrypt a block with a round key. @@ -85,32 +146,93 @@ pub const Block = struct { const s2 = block.repr[2]; const s3 = block.repr[3]; - const t0 = round_key.repr[0] ^ table_decrypt[0][@truncate(u8, s0 >> 24)] ^ table_decrypt[1][@truncate(u8, s3 >> 16)] ^ table_decrypt[2][@truncate(u8, s2 >> 8)] ^ table_decrypt[3][@truncate(u8, s1)]; - const t1 = round_key.repr[1] ^ table_decrypt[0][@truncate(u8, s1 >> 24)] ^ table_decrypt[1][@truncate(u8, s0 >> 16)] ^ table_decrypt[2][@truncate(u8, s3 >> 8)] ^ table_decrypt[3][@truncate(u8, s2)]; - const t2 = round_key.repr[2] ^ table_decrypt[0][@truncate(u8, s2 >> 24)] ^ table_decrypt[1][@truncate(u8, s1 >> 16)] ^ table_decrypt[2][@truncate(u8, s0 >> 8)] ^ table_decrypt[3][@truncate(u8, s3)]; - const t3 = round_key.repr[3] ^ table_decrypt[0][@truncate(u8, s3 >> 24)] ^ table_decrypt[1][@truncate(u8, s2 >> 16)] ^ table_decrypt[2][@truncate(u8, s1 >> 8)] ^ table_decrypt[3][@truncate(u8, s0)]; + var x: [4]u32 = undefined; + x = table_lookup(&table_decrypt, @truncate(u8, s0), @truncate(u8, s3 >> 8), @truncate(u8, s2 >> 16), @truncate(u8, s1 >> 24)); + var t0 = x[0] ^ x[1] ^ x[2] ^ x[3]; + x = table_lookup(&table_decrypt, @truncate(u8, s1), @truncate(u8, s0 >> 8), @truncate(u8, s3 >> 16), @truncate(u8, s2 >> 24)); + var t1 = x[0] ^ x[1] ^ x[2] ^ x[3]; + x = table_lookup(&table_decrypt, @truncate(u8, s2), @truncate(u8, s1 >> 8), @truncate(u8, s0 >> 16), @truncate(u8, s3 >> 24)); + var t2 = x[0] ^ x[1] ^ x[2] ^ x[3]; + x = table_lookup(&table_decrypt, @truncate(u8, s3), @truncate(u8, s2 >> 8), @truncate(u8, s1 >> 16), @truncate(u8, s0 >> 24)); + var t3 = x[0] ^ x[1] ^ x[2] ^ x[3]; + + t0 ^= round_key.repr[0]; + t1 ^= 
round_key.repr[1]; + t2 ^= round_key.repr[2]; + t3 ^= round_key.repr[3]; + + return Block{ .repr = BlockVec{ t0, t1, t2, t3 } }; + } + + /// Decrypt a block with a round key *WITHOUT ANY PROTECTION AGAINST SIDE CHANNELS* + pub inline fn decryptUnprotected(block: Block, round_key: Block) Block { + const s0 = block.repr[0]; + const s1 = block.repr[1]; + const s2 = block.repr[2]; + const s3 = block.repr[3]; + + var x: [4]u32 = undefined; + x = .{ + table_decrypt[0][@truncate(u8, s0)], + table_decrypt[1][@truncate(u8, s3 >> 8)], + table_decrypt[2][@truncate(u8, s2 >> 16)], + table_decrypt[3][@truncate(u8, s1 >> 24)], + }; + var t0 = x[0] ^ x[1] ^ x[2] ^ x[3]; + x = .{ + table_decrypt[0][@truncate(u8, s1)], + table_decrypt[1][@truncate(u8, s0 >> 8)], + table_decrypt[2][@truncate(u8, s3 >> 16)], + table_decrypt[3][@truncate(u8, s2 >> 24)], + }; + var t1 = x[0] ^ x[1] ^ x[2] ^ x[3]; + x = .{ + table_decrypt[0][@truncate(u8, s2)], + table_decrypt[1][@truncate(u8, s1 >> 8)], + table_decrypt[2][@truncate(u8, s0 >> 16)], + table_decrypt[3][@truncate(u8, s3 >> 24)], + }; + var t2 = x[0] ^ x[1] ^ x[2] ^ x[3]; + x = .{ + table_decrypt[0][@truncate(u8, s3)], + table_decrypt[1][@truncate(u8, s2 >> 8)], + table_decrypt[2][@truncate(u8, s1 >> 16)], + table_decrypt[3][@truncate(u8, s0 >> 24)], + }; + var t3 = x[0] ^ x[1] ^ x[2] ^ x[3]; + + t0 ^= round_key.repr[0]; + t1 ^= round_key.repr[1]; + t2 ^= round_key.repr[2]; + t3 ^= round_key.repr[3]; return Block{ .repr = BlockVec{ t0, t1, t2, t3 } }; } /// Decrypt a block with the last round key. pub inline fn decryptLast(block: Block, round_key: Block) Block { - const t0 = block.repr[0]; - const t1 = block.repr[1]; - const t2 = block.repr[2]; - const t3 = block.repr[3]; + const s0 = block.repr[0]; + const s1 = block.repr[1]; + const s2 = block.repr[2]; + const s3 = block.repr[3]; // Last round uses s-box directly and XORs to produce output. 
- var s0 = @as(u32, sbox_decrypt[t0 >> 24]) << 24 | @as(u32, sbox_decrypt[t3 >> 16 & 0xff]) << 16 | @as(u32, sbox_decrypt[t2 >> 8 & 0xff]) << 8 | @as(u32, sbox_decrypt[t1 & 0xff]); - var s1 = @as(u32, sbox_decrypt[t1 >> 24]) << 24 | @as(u32, sbox_decrypt[t0 >> 16 & 0xff]) << 16 | @as(u32, sbox_decrypt[t3 >> 8 & 0xff]) << 8 | @as(u32, sbox_decrypt[t2 & 0xff]); - var s2 = @as(u32, sbox_decrypt[t2 >> 24]) << 24 | @as(u32, sbox_decrypt[t1 >> 16 & 0xff]) << 16 | @as(u32, sbox_decrypt[t0 >> 8 & 0xff]) << 8 | @as(u32, sbox_decrypt[t3 & 0xff]); - var s3 = @as(u32, sbox_decrypt[t3 >> 24]) << 24 | @as(u32, sbox_decrypt[t2 >> 16 & 0xff]) << 16 | @as(u32, sbox_decrypt[t1 >> 8 & 0xff]) << 8 | @as(u32, sbox_decrypt[t0 & 0xff]); - s0 ^= round_key.repr[0]; - s1 ^= round_key.repr[1]; - s2 ^= round_key.repr[2]; - s3 ^= round_key.repr[3]; + var x: [4]u8 = undefined; + x = sbox_lookup(&sbox_decrypt, @truncate(u8, s1 >> 24), @truncate(u8, s2 >> 16), @truncate(u8, s3 >> 8), @truncate(u8, s0)); + var t0 = @as(u32, x[0]) << 24 | @as(u32, x[1]) << 16 | @as(u32, x[2]) << 8 | @as(u32, x[3]); + x = sbox_lookup(&sbox_decrypt, @truncate(u8, s2 >> 24), @truncate(u8, s3 >> 16), @truncate(u8, s0 >> 8), @truncate(u8, s1)); + var t1 = @as(u32, x[0]) << 24 | @as(u32, x[1]) << 16 | @as(u32, x[2]) << 8 | @as(u32, x[3]); + x = sbox_lookup(&sbox_decrypt, @truncate(u8, s3 >> 24), @truncate(u8, s0 >> 16), @truncate(u8, s1 >> 8), @truncate(u8, s2)); + var t2 = @as(u32, x[0]) << 24 | @as(u32, x[1]) << 16 | @as(u32, x[2]) << 8 | @as(u32, x[3]); + x = sbox_lookup(&sbox_decrypt, @truncate(u8, s0 >> 24), @truncate(u8, s1 >> 16), @truncate(u8, s2 >> 8), @truncate(u8, s3)); + var t3 = @as(u32, x[0]) << 24 | @as(u32, x[1]) << 16 | @as(u32, x[2]) << 8 | @as(u32, x[3]); - return Block{ .repr = BlockVec{ s0, s1, s2, s3 } }; + t0 ^= round_key.repr[0]; + t1 ^= round_key.repr[1]; + t2 ^= round_key.repr[2]; + t3 ^= round_key.repr[3]; + + return Block{ .repr = BlockVec{ t0, t1, t2, t3 } }; } /// Apply the bitwise XOR 
operation to the content of two blocks. @@ -226,7 +348,8 @@ fn KeySchedule(comptime Aes: type) type { const subw = struct { // Apply sbox_encrypt to each byte in w. fn func(w: u32) u32 { - return @as(u32, sbox_encrypt[w >> 24]) << 24 | @as(u32, sbox_encrypt[w >> 16 & 0xff]) << 16 | @as(u32, sbox_encrypt[w >> 8 & 0xff]) << 8 | @as(u32, sbox_encrypt[w & 0xff]); + const x = sbox_lookup(&sbox_key_schedule, @truncate(u8, w), @truncate(u8, w >> 8), @truncate(u8, w >> 16), @truncate(u8, w >> 24)); + return @as(u32, x[3]) << 24 | @as(u32, x[2]) << 16 | @as(u32, x[1]) << 8 | @as(u32, x[0]); } }.func; @@ -244,6 +367,10 @@ fn KeySchedule(comptime Aes: type) type { } round_keys[i / 4].repr[i % 4] = round_keys[(i - words_in_key) / 4].repr[(i - words_in_key) % 4] ^ t; } + i = 0; + inline while (i < round_keys.len * 4) : (i += 1) { + round_keys[i / 4].repr[i % 4] = @byteSwap(round_keys[i / 4].repr[i % 4]); + } return Self{ .round_keys = round_keys }; } @@ -257,11 +384,13 @@ fn KeySchedule(comptime Aes: type) type { const ei = total_words - i - 4; comptime var j: usize = 0; inline while (j < 4) : (j += 1) { - var x = round_keys[(ei + j) / 4].repr[(ei + j) % 4]; + var rk = round_keys[(ei + j) / 4].repr[(ei + j) % 4]; if (i > 0 and i + 4 < total_words) { - x = table_decrypt[0][sbox_encrypt[x >> 24]] ^ table_decrypt[1][sbox_encrypt[x >> 16 & 0xff]] ^ table_decrypt[2][sbox_encrypt[x >> 8 & 0xff]] ^ table_decrypt[3][sbox_encrypt[x & 0xff]]; + const x = sbox_lookup(&sbox_key_schedule, @truncate(u8, rk >> 24), @truncate(u8, rk >> 16), @truncate(u8, rk >> 8), @truncate(u8, rk)); + const y = table_lookup(&table_decrypt, x[3], x[2], x[1], x[0]); + rk = y[0] ^ y[1] ^ y[2] ^ y[3]; } - inv_round_keys[(i + j) / 4].repr[(i + j) % 4] = x; + inv_round_keys[(i + j) / 4].repr[(i + j) % 4] = rk; } } return Self{ .round_keys = inv_round_keys }; @@ -293,7 +422,17 @@ pub fn AesEncryptCtx(comptime Aes: type) type { const round_keys = ctx.key_schedule.round_keys; var t = 
Block.fromBytes(src).xorBlocks(round_keys[0]); comptime var i = 1; - inline while (i < rounds) : (i += 1) { + if (side_channels_mitigations == .full) { + inline while (i < rounds) : (i += 1) { + t = t.encrypt(round_keys[i]); + } + } else { + inline while (i < 5) : (i += 1) { + t = t.encrypt(round_keys[i]); + } + inline while (i < rounds - 1) : (i += 1) { + t = t.encryptUnprotected(round_keys[i]); + } t = t.encrypt(round_keys[i]); } t = t.encryptLast(round_keys[rounds]); @@ -305,7 +444,17 @@ pub fn AesEncryptCtx(comptime Aes: type) type { const round_keys = ctx.key_schedule.round_keys; var t = Block.fromBytes(&counter).xorBlocks(round_keys[0]); comptime var i = 1; - inline while (i < rounds) : (i += 1) { + if (side_channels_mitigations == .full) { + inline while (i < rounds) : (i += 1) { + t = t.encrypt(round_keys[i]); + } + } else { + inline while (i < 5) : (i += 1) { + t = t.encrypt(round_keys[i]); + } + inline while (i < rounds - 1) : (i += 1) { + t = t.encryptUnprotected(round_keys[i]); + } t = t.encrypt(round_keys[i]); } t = t.encryptLast(round_keys[rounds]); @@ -359,7 +508,17 @@ pub fn AesDecryptCtx(comptime Aes: type) type { const inv_round_keys = ctx.key_schedule.round_keys; var t = Block.fromBytes(src).xorBlocks(inv_round_keys[0]); comptime var i = 1; - inline while (i < rounds) : (i += 1) { + if (side_channels_mitigations == .full) { + inline while (i < rounds) : (i += 1) { + t = t.decrypt(inv_round_keys[i]); + } + } else { + inline while (i < 5) : (i += 1) { + t = t.decrypt(inv_round_keys[i]); + } + inline while (i < rounds - 1) : (i += 1) { + t = t.decryptUnprotected(inv_round_keys[i]); + } t = t.decrypt(inv_round_keys[i]); } t = t.decryptLast(inv_round_keys[rounds]); @@ -428,10 +587,11 @@ const powx = init: { break :init array; }; -const sbox_encrypt align(64) = generateSbox(false); -const sbox_decrypt align(64) = generateSbox(true); -const table_encrypt align(64) = generateTable(false); -const table_decrypt align(64) = generateTable(true); +const 
sbox_encrypt align(64) = generateSbox(false); // S-box for encryption +const sbox_key_schedule align(64) = generateSbox(false); // S-box only for key schedule, so that it uses distinct L1 cache entries than the S-box used for encryption +const sbox_decrypt align(64) = generateSbox(true); // S-box for decryption +const table_encrypt align(64) = generateTable(false); // 4-byte LUTs for encryption +const table_decrypt align(64) = generateTable(true); // 4-byte LUTs for decryption // Generate S-box substitution values. fn generateSbox(invert: bool) [256]u8 { @@ -472,14 +632,14 @@ fn generateTable(invert: bool) [4][256]u32 { var table: [4][256]u32 = undefined; for (generateSbox(invert), 0..) |value, index| { - table[0][index] = mul(value, if (invert) 0xb else 0x3); - table[0][index] |= math.shl(u32, mul(value, if (invert) 0xd else 0x1), 8); - table[0][index] |= math.shl(u32, mul(value, if (invert) 0x9 else 0x1), 16); - table[0][index] |= math.shl(u32, mul(value, if (invert) 0xe else 0x2), 24); + table[0][index] = math.shl(u32, mul(value, if (invert) 0xb else 0x3), 24); + table[0][index] |= math.shl(u32, mul(value, if (invert) 0xd else 0x1), 16); + table[0][index] |= math.shl(u32, mul(value, if (invert) 0x9 else 0x1), 8); + table[0][index] |= mul(value, if (invert) 0xe else 0x2); - table[1][index] = math.rotr(u32, table[0][index], 8); - table[2][index] = math.rotr(u32, table[0][index], 16); - table[3][index] = math.rotr(u32, table[0][index], 24); + table[1][index] = math.rotl(u32, table[0][index], 8); + table[2][index] = math.rotl(u32, table[0][index], 16); + table[3][index] = math.rotl(u32, table[0][index], 24); } return table; @@ -506,3 +666,82 @@ fn mul(a: u8, b: u8) u8 { return @truncate(u8, s); } + +const cache_line_bytes = 64; + +inline fn sbox_lookup(sbox: *align(64) const [256]u8, idx0: u8, idx1: u8, idx2: u8, idx3: u8) [4]u8 { + if (side_channels_mitigations == .none) { + return [4]u8{ + sbox[idx0], + sbox[idx1], + sbox[idx2], + sbox[idx3], + }; + } else { + 
const stride = switch (side_channels_mitigations) { + .none => unreachable, + .basic => sbox.len / 4, + .medium => sbox.len / (sbox.len / cache_line_bytes) * 2, + .full => sbox.len / (sbox.len / cache_line_bytes), + }; + const of0 = idx0 % stride; + const of1 = idx1 % stride; + const of2 = idx2 % stride; + const of3 = idx3 % stride; + var t: [4][sbox.len / stride]u8 align(64) = undefined; + var i: usize = 0; + while (i < t[0].len) : (i += 1) { + const tx = sbox[i * stride ..]; + t[0][i] = tx[of0]; + t[1][i] = tx[of1]; + t[2][i] = tx[of2]; + t[3][i] = tx[of3]; + } + std.mem.doNotOptimizeAway(t); + return [4]u8{ + t[0][idx0 / stride], + t[1][idx1 / stride], + t[2][idx2 / stride], + t[3][idx3 / stride], + }; + } +} + +inline fn table_lookup(table: *align(64) const [4][256]u32, idx0: u8, idx1: u8, idx2: u8, idx3: u8) [4]u32 { + if (side_channels_mitigations == .none) { + return [4]u32{ + table[0][idx0], + table[1][idx1], + table[2][idx2], + table[3][idx3], + }; + } else { + const table_bytes = @sizeOf(@TypeOf(table[0])); + const stride = switch (side_channels_mitigations) { + .none => unreachable, + .basic => table[0].len / 4, + .medium => table[0].len / (table_bytes / cache_line_bytes) * 2, + .full => table[0].len / (table_bytes / cache_line_bytes), + }; + const of0 = idx0 % stride; + const of1 = idx1 % stride; + const of2 = idx2 % stride; + const of3 = idx3 % stride; + var t: [4][table[0].len / stride]u32 align(64) = undefined; + var i: usize = 0; + while (i < t[0].len) : (i += 1) { + const tx = table[0][i * stride ..]; + t[0][i] = tx[of0]; + t[1][i] = tx[of1]; + t[2][i] = tx[of2]; + t[3][i] = tx[of3]; + } + std.mem.doNotOptimizeAway(t); + return [4]u32{ + t[0][idx0 / stride], + math.rotl(u32, t[1][idx1 / stride], 8), + math.rotl(u32, t[2][idx2 / stride], 16), + math.rotl(u32, t[3][idx3 / stride], 24), + }; + } +}