Add configurable side-channel mitigations; enable them on soft AES (#13739)

* Add configurable side-channel mitigations; enable them on soft AES

Our software AES implementation doesn't have any mitigations against
side channels.

Go's generic implementation is not protected at all either, and even
OpenSSL only has minimal mitigations.

Full mitigations against cache-based attacks (bitslicing, fixslicing)
come at a huge performance cost, making AES-based primitives pretty
much useless for many applications. They also don't offer any
protection against other classes of side-channel attacks.

In practice, partially protected, or even unprotected, implementations
are not as bad as they sound. Exploiting these side channels requires
an attacker that is able to submit many plaintexts/ciphertexts and
perform accurate measurements. Noisy measurements can still be
exploited, but require a significant number of attempts. Whether this
is exploitable depends on the platform, the application, and the
attacker's proximity.

So, some libraries made the choice of minimal mitigations, and some
use better mitigations in spite of the performance hit. It's a
tradeoff (security vs. performance), and there's no one-size-fits-all
implementation.

What applies to AES applies to other cryptographic primitives.

For example, RSA signatures are very sensitive to fault attacks,
regardless of whether they use the CRT. A mitigation is to verify
every produced signature. That also comes with a performance cost.
Whether to do it depends on whether fault attacks are part of the
threat model.
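
As a rough sketch of that mitigation (using a hypothetical `rsa`
module; these names are placeholders, not an existing API):

    // Hypothetical RSA module; `rsa.sign`/`rsa.verify` are placeholders.
    fn signWithFaultCheck(kp: rsa.KeyPair, msg: []const u8) !rsa.Signature {
        const sig = try rsa.sign(kp.secret_key, msg);
        // Verifying is much cheaper than signing, but not free: only pay
        // for it when fault attacks are part of the threat model.
        rsa.verify(kp.public_key, msg, sig) catch return error.FaultDetected;
        return sig;
    }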

Thanks to Zig's comptime, we can try to address these different
requirements.

This PR adds a `side_channels_mitigations` global, which can later
be complemented with `fault_attacks_protection` and possibly other
knobs.

It can have 4 different values:

- `none`: which doesn't enable any additional mitigations.
"Additional", because mitigations that come at little or no
performance cost are always enabled. For example, checking
authentication tags will still be done in constant time.

- `basic`: which enables mitigations protecting against attacks in
a common scenario, where an attacker doesn't have physical access to
the device, cannot run arbitrary code on the same thread, and cannot
conduct brute-force attacks without being throttled.

- `medium`: which enables additional mitigations, offering practical
protection in a shared environment.

- `full`: which enables all the mitigations we have.

The tradeoff is that the more mitigations we enable, the bigger the
performance hit will be. But this lets applications choose what's
best for their use case.

`medium` is the default.

Currently, this only affects software AES, but that setting can
later be used by other primitives.
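
For example, since the setting is looked up in the root source file
via `@hasDecl`, an application can select a different level by
declaring it at the top level of its root file. A minimal sketch
(the file name is illustrative):

    // main.zig (the application's root source file)
    const std = @import("std");

    // Override the default (`medium`) mitigation level for the whole program.
    pub const side_channels_mitigations: std.crypto.config.SideChannelsMitigations = .full;

    pub fn main() void {
        // Every std.crypto primitive honoring the setting now runs at `full`.
        const key = [_]u8{0} ** 16;
        const plaintext = [_]u8{0} ** 16;
        var ciphertext: [16]u8 = undefined;
        const ctx = std.crypto.core.aes.Aes128.initEnc(key);
        ctx.encrypt(&ciphertext, &plaintext);
    }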

For AES, our implementation is a traditional table-based one, with
four tables of 32-bit words and an S-box.

Lookups into these tables have been replaced by function calls. These
functions can add a configurable noise level, making cache-based
attacks more difficult to conduct.

In the `none` mitigation level, the behavior is exactly the same
as before. Performance also remains the same.

In other levels, we compress the T tables into a single one, and
read data from multiple cache lines (all of them in `full` mode),
for all bytes in parallel. More precise measurements and many more
attempts become necessary in order to find correlations.
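
To make the cache-line arithmetic concrete, here is a sketch of how
many 64-byte cache lines a single lookup into the compressed T table
(256 u32 entries, i.e. 1024 bytes) touches at each level, mirroring
the strides computed in `table_lookup` below:

    const std = @import("std");

    test "cache lines touched per T-table lookup" {
        const cache_line_bytes = 64;
        const table_len = 256; // u32 entries in the compressed T table
        const table_bytes = table_len * @sizeOf(u32); // 1024 bytes = 16 cache lines

        // One entry is read every `stride` entries, so a lookup performs
        // table_len / stride reads, each landing in a distinct cache line.
        const stride_basic = table_len / 4; // 64
        const stride_medium = table_len / (table_bytes / cache_line_bytes) * 2; // 32
        const stride_full = table_len / (table_bytes / cache_line_bytes); // 16

        try std.testing.expectEqual(@as(usize, 4), table_len / stride_basic);
        try std.testing.expectEqual(@as(usize, 8), table_len / stride_medium);
        // `full` touches all 16 cache lines of the table on every lookup.
        try std.testing.expectEqual(@as(usize, 16), table_len / stride_full);
    }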

In addition, we use distinct copies of the sbox for key expansion
and encryption, so that they don't share the same L1 cache entries.

The best known attacks target the first two AES rounds, or the last
one.

While future attacks may improve on this, AES achieves full
diffusion after 4 rounds. So, we can relax the mitigations after
that point. This is what this implementation does, re-enabling the
mitigations for the last two rounds.

In `full` mode, all the rounds are protected.
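
Concretely, for AES-128 (10 rounds) in any level other than `full`,
the encryption path below applies the mitigations as follows:

    // Sketch of the round schedule in AesEncryptCtx.encrypt() for rounds = 10:
    //   round 0        : whitening XOR with round_keys[0] (no table lookups)
    //   rounds 1..4    : Block.encrypt()            -> protected table_lookup()
    //   rounds 5..8    : Block.encryptUnprotected() -> plain table lookups
    //   round 9        : Block.encrypt()            -> protected table_lookup()
    //   round 10 (last): Block.encryptLast()        -> protected sbox_lookup()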

The protection assumes that the index of a lookup within a cache
line remains secret. The CacheBleed attack showed that this can be
circumvented, but that requires an attacker to be able to abuse
hyperthreading and run code on the same core as the encryption,
which is rarely a practical scenario.

Still, the current AES API allows us to transparently switch to
using fixslicing/bitslicing later when the `full` mitigation level
is enabled.

* Software AES: use little-endian representation.

Virtually all platforms are little-endian these days, so optimizing
for big-endian CPUs doesn't make sense any more.
Frank Denis 2023-03-13 22:18:26 +01:00, committed by GitHub
parent d525ecb523
commit 9622991578
2 changed files with 328 additions and 62 deletions


@@ -1,3 +1,5 @@
+const root = @import("root");
+
/// Authenticated Encryption with Associated Data
pub const aead = struct {
pub const aegis = struct {
@@ -183,6 +185,31 @@ pub const errors = @import("crypto/errors.zig");
pub const tls = @import("crypto/tls.zig");
pub const Certificate = @import("crypto/Certificate.zig");
+/// Global configuration of cryptographic implementations in the standard library.
+pub const config = struct {
+/// Side-channel mitigations.
+pub const SideChannelsMitigations = enum {
+/// No additional side-channel mitigations are applied.
+/// This is the fastest mode.
+none,
+/// The `basic` mode protects against most practical attacks, provided that the
+/// application implements proper defenses against brute-force attacks.
+/// It offers a good balance between performance and security.
+basic,
+/// The `medium` mode offers increased resilience against side-channel attacks,
+/// making most attacks impractical even in shared/low-latency environments.
+/// This is the default mode.
+medium,
+/// The `full` mode offers the highest level of protection against side-channel attacks.
+/// Note that this doesn't cover all possible attacks (especially power analysis or
+/// thread-local attacks such as cachebleed), and that the performance impact is significant.
+full,
+};
+/// This is a global configuration that applies to all cryptographic implementations.
+pub const side_channels_mitigations: SideChannelsMitigations = if (@hasDecl(root, "side_channels_mitigations")) root.side_channels_mitigations else .medium;
+};
test {
_ = aead.aegis.Aegis128L;
_ = aead.aegis.Aegis256;


@@ -1,11 +1,11 @@
// Based on Go stdlib implementation
const std = @import("../../std.zig");
const math = std.math;
const mem = std.mem;
const BlockVec = [4]u32;
+const side_channels_mitigations = std.crypto.config.side_channels_mitigations;
/// A single AES block.
pub const Block = struct {
pub const block_length: usize = 16;
@@ -15,20 +15,20 @@ pub const Block = struct {
/// Convert a byte sequence into an internal representation.
pub inline fn fromBytes(bytes: *const [16]u8) Block {
-const s0 = mem.readIntBig(u32, bytes[0..4]);
-const s1 = mem.readIntBig(u32, bytes[4..8]);
-const s2 = mem.readIntBig(u32, bytes[8..12]);
-const s3 = mem.readIntBig(u32, bytes[12..16]);
+const s0 = mem.readIntLittle(u32, bytes[0..4]);
+const s1 = mem.readIntLittle(u32, bytes[4..8]);
+const s2 = mem.readIntLittle(u32, bytes[8..12]);
+const s3 = mem.readIntLittle(u32, bytes[12..16]);
return Block{ .repr = BlockVec{ s0, s1, s2, s3 } };
}
/// Convert the internal representation of a block into a byte sequence.
pub inline fn toBytes(block: Block) [16]u8 {
var bytes: [16]u8 = undefined;
-mem.writeIntBig(u32, bytes[0..4], block.repr[0]);
-mem.writeIntBig(u32, bytes[4..8], block.repr[1]);
-mem.writeIntBig(u32, bytes[8..12], block.repr[2]);
-mem.writeIntBig(u32, bytes[12..16], block.repr[3]);
+mem.writeIntLittle(u32, bytes[0..4], block.repr[0]);
+mem.writeIntLittle(u32, bytes[4..8], block.repr[1]);
+mem.writeIntLittle(u32, bytes[8..12], block.repr[2]);
+mem.writeIntLittle(u32, bytes[12..16], block.repr[3]);
return bytes;
}
@@ -50,32 +50,93 @@ pub const Block = struct {
const s2 = block.repr[2];
const s3 = block.repr[3];
-const t0 = round_key.repr[0] ^ table_encrypt[0][@truncate(u8, s0 >> 24)] ^ table_encrypt[1][@truncate(u8, s1 >> 16)] ^ table_encrypt[2][@truncate(u8, s2 >> 8)] ^ table_encrypt[3][@truncate(u8, s3)];
-const t1 = round_key.repr[1] ^ table_encrypt[0][@truncate(u8, s1 >> 24)] ^ table_encrypt[1][@truncate(u8, s2 >> 16)] ^ table_encrypt[2][@truncate(u8, s3 >> 8)] ^ table_encrypt[3][@truncate(u8, s0)];
-const t2 = round_key.repr[2] ^ table_encrypt[0][@truncate(u8, s2 >> 24)] ^ table_encrypt[1][@truncate(u8, s3 >> 16)] ^ table_encrypt[2][@truncate(u8, s0 >> 8)] ^ table_encrypt[3][@truncate(u8, s1)];
-const t3 = round_key.repr[3] ^ table_encrypt[0][@truncate(u8, s3 >> 24)] ^ table_encrypt[1][@truncate(u8, s0 >> 16)] ^ table_encrypt[2][@truncate(u8, s1 >> 8)] ^ table_encrypt[3][@truncate(u8, s2)];
+var x: [4]u32 = undefined;
+x = table_lookup(&table_encrypt, @truncate(u8, s0), @truncate(u8, s1 >> 8), @truncate(u8, s2 >> 16), @truncate(u8, s3 >> 24));
+var t0 = x[0] ^ x[1] ^ x[2] ^ x[3];
+x = table_lookup(&table_encrypt, @truncate(u8, s1), @truncate(u8, s2 >> 8), @truncate(u8, s3 >> 16), @truncate(u8, s0 >> 24));
+var t1 = x[0] ^ x[1] ^ x[2] ^ x[3];
+x = table_lookup(&table_encrypt, @truncate(u8, s2), @truncate(u8, s3 >> 8), @truncate(u8, s0 >> 16), @truncate(u8, s1 >> 24));
+var t2 = x[0] ^ x[1] ^ x[2] ^ x[3];
+x = table_lookup(&table_encrypt, @truncate(u8, s3), @truncate(u8, s0 >> 8), @truncate(u8, s1 >> 16), @truncate(u8, s2 >> 24));
+var t3 = x[0] ^ x[1] ^ x[2] ^ x[3];
+t0 ^= round_key.repr[0];
+t1 ^= round_key.repr[1];
+t2 ^= round_key.repr[2];
+t3 ^= round_key.repr[3];
}
+/// Encrypt a block with a round key *WITHOUT ANY PROTECTION AGAINST SIDE CHANNELS*
+pub inline fn encryptUnprotected(block: Block, round_key: Block) Block {
+const s0 = block.repr[0];
+const s1 = block.repr[1];
+const s2 = block.repr[2];
+const s3 = block.repr[3];
+var x: [4]u32 = undefined;
+x = .{
+table_encrypt[0][@truncate(u8, s0)],
+table_encrypt[1][@truncate(u8, s1 >> 8)],
+table_encrypt[2][@truncate(u8, s2 >> 16)],
+table_encrypt[3][@truncate(u8, s3 >> 24)],
+};
+var t0 = x[0] ^ x[1] ^ x[2] ^ x[3];
+x = .{
+table_encrypt[0][@truncate(u8, s1)],
+table_encrypt[1][@truncate(u8, s2 >> 8)],
+table_encrypt[2][@truncate(u8, s3 >> 16)],
+table_encrypt[3][@truncate(u8, s0 >> 24)],
+};
+var t1 = x[0] ^ x[1] ^ x[2] ^ x[3];
+x = .{
+table_encrypt[0][@truncate(u8, s2)],
+table_encrypt[1][@truncate(u8, s3 >> 8)],
+table_encrypt[2][@truncate(u8, s0 >> 16)],
+table_encrypt[3][@truncate(u8, s1 >> 24)],
+};
+var t2 = x[0] ^ x[1] ^ x[2] ^ x[3];
+x = .{
+table_encrypt[0][@truncate(u8, s3)],
+table_encrypt[1][@truncate(u8, s0 >> 8)],
+table_encrypt[2][@truncate(u8, s1 >> 16)],
+table_encrypt[3][@truncate(u8, s2 >> 24)],
+};
+var t3 = x[0] ^ x[1] ^ x[2] ^ x[3];
+t0 ^= round_key.repr[0];
+t1 ^= round_key.repr[1];
+t2 ^= round_key.repr[2];
+t3 ^= round_key.repr[3];
+return Block{ .repr = BlockVec{ t0, t1, t2, t3 } };
+}
/// Encrypt a block with the last round key.
pub inline fn encryptLast(block: Block, round_key: Block) Block {
-const t0 = block.repr[0];
-const t1 = block.repr[1];
-const t2 = block.repr[2];
-const t3 = block.repr[3];
+const s0 = block.repr[0];
+const s1 = block.repr[1];
+const s2 = block.repr[2];
+const s3 = block.repr[3];
// Last round uses s-box directly and XORs to produce output.
-var s0 = @as(u32, sbox_encrypt[t0 >> 24]) << 24 | @as(u32, sbox_encrypt[t1 >> 16 & 0xff]) << 16 | @as(u32, sbox_encrypt[t2 >> 8 & 0xff]) << 8 | @as(u32, sbox_encrypt[t3 & 0xff]);
-var s1 = @as(u32, sbox_encrypt[t1 >> 24]) << 24 | @as(u32, sbox_encrypt[t2 >> 16 & 0xff]) << 16 | @as(u32, sbox_encrypt[t3 >> 8 & 0xff]) << 8 | @as(u32, sbox_encrypt[t0 & 0xff]);
-var s2 = @as(u32, sbox_encrypt[t2 >> 24]) << 24 | @as(u32, sbox_encrypt[t3 >> 16 & 0xff]) << 16 | @as(u32, sbox_encrypt[t0 >> 8 & 0xff]) << 8 | @as(u32, sbox_encrypt[t1 & 0xff]);
-var s3 = @as(u32, sbox_encrypt[t3 >> 24]) << 24 | @as(u32, sbox_encrypt[t0 >> 16 & 0xff]) << 16 | @as(u32, sbox_encrypt[t1 >> 8 & 0xff]) << 8 | @as(u32, sbox_encrypt[t2 & 0xff]);
-s0 ^= round_key.repr[0];
-s1 ^= round_key.repr[1];
-s2 ^= round_key.repr[2];
-s3 ^= round_key.repr[3];
+var x: [4]u8 = undefined;
+x = sbox_lookup(&sbox_encrypt, @truncate(u8, s3 >> 24), @truncate(u8, s2 >> 16), @truncate(u8, s1 >> 8), @truncate(u8, s0));
+var t0 = @as(u32, x[0]) << 24 | @as(u32, x[1]) << 16 | @as(u32, x[2]) << 8 | @as(u32, x[3]);
+x = sbox_lookup(&sbox_encrypt, @truncate(u8, s0 >> 24), @truncate(u8, s3 >> 16), @truncate(u8, s2 >> 8), @truncate(u8, s1));
+var t1 = @as(u32, x[0]) << 24 | @as(u32, x[1]) << 16 | @as(u32, x[2]) << 8 | @as(u32, x[3]);
+x = sbox_lookup(&sbox_encrypt, @truncate(u8, s1 >> 24), @truncate(u8, s0 >> 16), @truncate(u8, s3 >> 8), @truncate(u8, s2));
+var t2 = @as(u32, x[0]) << 24 | @as(u32, x[1]) << 16 | @as(u32, x[2]) << 8 | @as(u32, x[3]);
+x = sbox_lookup(&sbox_encrypt, @truncate(u8, s2 >> 24), @truncate(u8, s1 >> 16), @truncate(u8, s0 >> 8), @truncate(u8, s3));
+var t3 = @as(u32, x[0]) << 24 | @as(u32, x[1]) << 16 | @as(u32, x[2]) << 8 | @as(u32, x[3]);
-return Block{ .repr = BlockVec{ s0, s1, s2, s3 } };
+t0 ^= round_key.repr[0];
+t1 ^= round_key.repr[1];
+t2 ^= round_key.repr[2];
+t3 ^= round_key.repr[3];
+return Block{ .repr = BlockVec{ t0, t1, t2, t3 } };
}
/// Decrypt a block with a round key.
@@ -85,32 +146,93 @@ pub const Block = struct {
const s2 = block.repr[2];
const s3 = block.repr[3];
-const t0 = round_key.repr[0] ^ table_decrypt[0][@truncate(u8, s0 >> 24)] ^ table_decrypt[1][@truncate(u8, s3 >> 16)] ^ table_decrypt[2][@truncate(u8, s2 >> 8)] ^ table_decrypt[3][@truncate(u8, s1)];
-const t1 = round_key.repr[1] ^ table_decrypt[0][@truncate(u8, s1 >> 24)] ^ table_decrypt[1][@truncate(u8, s0 >> 16)] ^ table_decrypt[2][@truncate(u8, s3 >> 8)] ^ table_decrypt[3][@truncate(u8, s2)];
-const t2 = round_key.repr[2] ^ table_decrypt[0][@truncate(u8, s2 >> 24)] ^ table_decrypt[1][@truncate(u8, s1 >> 16)] ^ table_decrypt[2][@truncate(u8, s0 >> 8)] ^ table_decrypt[3][@truncate(u8, s3)];
-const t3 = round_key.repr[3] ^ table_decrypt[0][@truncate(u8, s3 >> 24)] ^ table_decrypt[1][@truncate(u8, s2 >> 16)] ^ table_decrypt[2][@truncate(u8, s1 >> 8)] ^ table_decrypt[3][@truncate(u8, s0)];
+var x: [4]u32 = undefined;
+x = table_lookup(&table_decrypt, @truncate(u8, s0), @truncate(u8, s3 >> 8), @truncate(u8, s2 >> 16), @truncate(u8, s1 >> 24));
+var t0 = x[0] ^ x[1] ^ x[2] ^ x[3];
+x = table_lookup(&table_decrypt, @truncate(u8, s1), @truncate(u8, s0 >> 8), @truncate(u8, s3 >> 16), @truncate(u8, s2 >> 24));
+var t1 = x[0] ^ x[1] ^ x[2] ^ x[3];
+x = table_lookup(&table_decrypt, @truncate(u8, s2), @truncate(u8, s1 >> 8), @truncate(u8, s0 >> 16), @truncate(u8, s3 >> 24));
+var t2 = x[0] ^ x[1] ^ x[2] ^ x[3];
+x = table_lookup(&table_decrypt, @truncate(u8, s3), @truncate(u8, s2 >> 8), @truncate(u8, s1 >> 16), @truncate(u8, s0 >> 24));
+var t3 = x[0] ^ x[1] ^ x[2] ^ x[3];
+t0 ^= round_key.repr[0];
+t1 ^= round_key.repr[1];
+t2 ^= round_key.repr[2];
+t3 ^= round_key.repr[3];
}
+/// Decrypt a block with a round key *WITHOUT ANY PROTECTION AGAINST SIDE CHANNELS*
+pub inline fn decryptUnprotected(block: Block, round_key: Block) Block {
+const s0 = block.repr[0];
+const s1 = block.repr[1];
+const s2 = block.repr[2];
+const s3 = block.repr[3];
+var x: [4]u32 = undefined;
+x = .{
+table_decrypt[0][@truncate(u8, s0)],
+table_decrypt[1][@truncate(u8, s3 >> 8)],
+table_decrypt[2][@truncate(u8, s2 >> 16)],
+table_decrypt[3][@truncate(u8, s1 >> 24)],
+};
+var t0 = x[0] ^ x[1] ^ x[2] ^ x[3];
+x = .{
+table_decrypt[0][@truncate(u8, s1)],
+table_decrypt[1][@truncate(u8, s0 >> 8)],
+table_decrypt[2][@truncate(u8, s3 >> 16)],
+table_decrypt[3][@truncate(u8, s2 >> 24)],
+};
+var t1 = x[0] ^ x[1] ^ x[2] ^ x[3];
+x = .{
+table_decrypt[0][@truncate(u8, s2)],
+table_decrypt[1][@truncate(u8, s1 >> 8)],
+table_decrypt[2][@truncate(u8, s0 >> 16)],
+table_decrypt[3][@truncate(u8, s3 >> 24)],
+};
+var t2 = x[0] ^ x[1] ^ x[2] ^ x[3];
+x = .{
+table_decrypt[0][@truncate(u8, s3)],
+table_decrypt[1][@truncate(u8, s2 >> 8)],
+table_decrypt[2][@truncate(u8, s1 >> 16)],
+table_decrypt[3][@truncate(u8, s0 >> 24)],
+};
+var t3 = x[0] ^ x[1] ^ x[2] ^ x[3];
+t0 ^= round_key.repr[0];
+t1 ^= round_key.repr[1];
+t2 ^= round_key.repr[2];
+t3 ^= round_key.repr[3];
+return Block{ .repr = BlockVec{ t0, t1, t2, t3 } };
+}
/// Decrypt a block with the last round key.
pub inline fn decryptLast(block: Block, round_key: Block) Block {
-const t0 = block.repr[0];
-const t1 = block.repr[1];
-const t2 = block.repr[2];
-const t3 = block.repr[3];
+const s0 = block.repr[0];
+const s1 = block.repr[1];
+const s2 = block.repr[2];
+const s3 = block.repr[3];
// Last round uses s-box directly and XORs to produce output.
-var s0 = @as(u32, sbox_decrypt[t0 >> 24]) << 24 | @as(u32, sbox_decrypt[t3 >> 16 & 0xff]) << 16 | @as(u32, sbox_decrypt[t2 >> 8 & 0xff]) << 8 | @as(u32, sbox_decrypt[t1 & 0xff]);
-var s1 = @as(u32, sbox_decrypt[t1 >> 24]) << 24 | @as(u32, sbox_decrypt[t0 >> 16 & 0xff]) << 16 | @as(u32, sbox_decrypt[t3 >> 8 & 0xff]) << 8 | @as(u32, sbox_decrypt[t2 & 0xff]);
-var s2 = @as(u32, sbox_decrypt[t2 >> 24]) << 24 | @as(u32, sbox_decrypt[t1 >> 16 & 0xff]) << 16 | @as(u32, sbox_decrypt[t0 >> 8 & 0xff]) << 8 | @as(u32, sbox_decrypt[t3 & 0xff]);
-var s3 = @as(u32, sbox_decrypt[t3 >> 24]) << 24 | @as(u32, sbox_decrypt[t2 >> 16 & 0xff]) << 16 | @as(u32, sbox_decrypt[t1 >> 8 & 0xff]) << 8 | @as(u32, sbox_decrypt[t0 & 0xff]);
-s0 ^= round_key.repr[0];
-s1 ^= round_key.repr[1];
-s2 ^= round_key.repr[2];
-s3 ^= round_key.repr[3];
+var x: [4]u8 = undefined;
+x = sbox_lookup(&sbox_decrypt, @truncate(u8, s1 >> 24), @truncate(u8, s2 >> 16), @truncate(u8, s3 >> 8), @truncate(u8, s0));
+var t0 = @as(u32, x[0]) << 24 | @as(u32, x[1]) << 16 | @as(u32, x[2]) << 8 | @as(u32, x[3]);
+x = sbox_lookup(&sbox_decrypt, @truncate(u8, s2 >> 24), @truncate(u8, s3 >> 16), @truncate(u8, s0 >> 8), @truncate(u8, s1));
+var t1 = @as(u32, x[0]) << 24 | @as(u32, x[1]) << 16 | @as(u32, x[2]) << 8 | @as(u32, x[3]);
+x = sbox_lookup(&sbox_decrypt, @truncate(u8, s3 >> 24), @truncate(u8, s0 >> 16), @truncate(u8, s1 >> 8), @truncate(u8, s2));
+var t2 = @as(u32, x[0]) << 24 | @as(u32, x[1]) << 16 | @as(u32, x[2]) << 8 | @as(u32, x[3]);
+x = sbox_lookup(&sbox_decrypt, @truncate(u8, s0 >> 24), @truncate(u8, s1 >> 16), @truncate(u8, s2 >> 8), @truncate(u8, s3));
+var t3 = @as(u32, x[0]) << 24 | @as(u32, x[1]) << 16 | @as(u32, x[2]) << 8 | @as(u32, x[3]);
-return Block{ .repr = BlockVec{ s0, s1, s2, s3 } };
+t0 ^= round_key.repr[0];
+t1 ^= round_key.repr[1];
+t2 ^= round_key.repr[2];
+t3 ^= round_key.repr[3];
+return Block{ .repr = BlockVec{ t0, t1, t2, t3 } };
}
/// Apply the bitwise XOR operation to the content of two blocks.
@@ -226,7 +348,8 @@ fn KeySchedule(comptime Aes: type) type {
const subw = struct {
// Apply sbox_encrypt to each byte in w.
fn func(w: u32) u32 {
-return @as(u32, sbox_encrypt[w >> 24]) << 24 | @as(u32, sbox_encrypt[w >> 16 & 0xff]) << 16 | @as(u32, sbox_encrypt[w >> 8 & 0xff]) << 8 | @as(u32, sbox_encrypt[w & 0xff]);
+const x = sbox_lookup(&sbox_key_schedule, @truncate(u8, w), @truncate(u8, w >> 8), @truncate(u8, w >> 16), @truncate(u8, w >> 24));
+return @as(u32, x[3]) << 24 | @as(u32, x[2]) << 16 | @as(u32, x[1]) << 8 | @as(u32, x[0]);
}
}.func;
@@ -244,6 +367,10 @@ fn KeySchedule(comptime Aes: type) type {
}
round_keys[i / 4].repr[i % 4] = round_keys[(i - words_in_key) / 4].repr[(i - words_in_key) % 4] ^ t;
}
+i = 0;
+inline while (i < round_keys.len * 4) : (i += 1) {
+round_keys[i / 4].repr[i % 4] = @byteSwap(round_keys[i / 4].repr[i % 4]);
+}
return Self{ .round_keys = round_keys };
}
@@ -257,11 +384,13 @@ fn KeySchedule(comptime Aes: type) type {
const ei = total_words - i - 4;
comptime var j: usize = 0;
inline while (j < 4) : (j += 1) {
-var x = round_keys[(ei + j) / 4].repr[(ei + j) % 4];
+var rk = round_keys[(ei + j) / 4].repr[(ei + j) % 4];
if (i > 0 and i + 4 < total_words) {
-x = table_decrypt[0][sbox_encrypt[x >> 24]] ^ table_decrypt[1][sbox_encrypt[x >> 16 & 0xff]] ^ table_decrypt[2][sbox_encrypt[x >> 8 & 0xff]] ^ table_decrypt[3][sbox_encrypt[x & 0xff]];
+const x = sbox_lookup(&sbox_key_schedule, @truncate(u8, rk >> 24), @truncate(u8, rk >> 16), @truncate(u8, rk >> 8), @truncate(u8, rk));
+const y = table_lookup(&table_decrypt, x[3], x[2], x[1], x[0]);
+rk = y[0] ^ y[1] ^ y[2] ^ y[3];
}
-inv_round_keys[(i + j) / 4].repr[(i + j) % 4] = x;
+inv_round_keys[(i + j) / 4].repr[(i + j) % 4] = rk;
}
}
return Self{ .round_keys = inv_round_keys };
@@ -293,7 +422,17 @@ pub fn AesEncryptCtx(comptime Aes: type) type {
const round_keys = ctx.key_schedule.round_keys;
var t = Block.fromBytes(src).xorBlocks(round_keys[0]);
comptime var i = 1;
-inline while (i < rounds) : (i += 1) {
-t = t.encrypt(round_keys[i]);
-}
+if (side_channels_mitigations == .full) {
+inline while (i < rounds) : (i += 1) {
+t = t.encrypt(round_keys[i]);
+}
+} else {
+inline while (i < 5) : (i += 1) {
+t = t.encrypt(round_keys[i]);
+}
+inline while (i < rounds - 1) : (i += 1) {
+t = t.encryptUnprotected(round_keys[i]);
+}
+t = t.encrypt(round_keys[i]);
+}
t = t.encryptLast(round_keys[rounds]);
@@ -305,7 +444,17 @@ pub fn AesEncryptCtx(comptime Aes: type) type {
const round_keys = ctx.key_schedule.round_keys;
var t = Block.fromBytes(&counter).xorBlocks(round_keys[0]);
comptime var i = 1;
-inline while (i < rounds) : (i += 1) {
-t = t.encrypt(round_keys[i]);
-}
+if (side_channels_mitigations == .full) {
+inline while (i < rounds) : (i += 1) {
+t = t.encrypt(round_keys[i]);
+}
+} else {
+inline while (i < 5) : (i += 1) {
+t = t.encrypt(round_keys[i]);
+}
+inline while (i < rounds - 1) : (i += 1) {
+t = t.encryptUnprotected(round_keys[i]);
+}
+t = t.encrypt(round_keys[i]);
+}
t = t.encryptLast(round_keys[rounds]);
@@ -359,7 +508,17 @@ pub fn AesDecryptCtx(comptime Aes: type) type {
const inv_round_keys = ctx.key_schedule.round_keys;
var t = Block.fromBytes(src).xorBlocks(inv_round_keys[0]);
comptime var i = 1;
-inline while (i < rounds) : (i += 1) {
-t = t.decrypt(inv_round_keys[i]);
-}
+if (side_channels_mitigations == .full) {
+inline while (i < rounds) : (i += 1) {
+t = t.decrypt(inv_round_keys[i]);
+}
+} else {
+inline while (i < 5) : (i += 1) {
+t = t.decrypt(inv_round_keys[i]);
+}
+inline while (i < rounds - 1) : (i += 1) {
+t = t.decryptUnprotected(inv_round_keys[i]);
+}
+t = t.decrypt(inv_round_keys[i]);
+}
t = t.decryptLast(inv_round_keys[rounds]);
@@ -428,10 +587,11 @@ const powx = init: {
break :init array;
};
-const sbox_encrypt align(64) = generateSbox(false);
-const sbox_decrypt align(64) = generateSbox(true);
-const table_encrypt align(64) = generateTable(false);
-const table_decrypt align(64) = generateTable(true);
+const sbox_encrypt align(64) = generateSbox(false); // S-box for encryption
+const sbox_key_schedule align(64) = generateSbox(false); // S-box only for key schedule, so that it uses L1 cache entries distinct from the S-box used for encryption
+const sbox_decrypt align(64) = generateSbox(true); // S-box for decryption
+const table_encrypt align(64) = generateTable(false); // 4-byte LUTs for encryption
+const table_decrypt align(64) = generateTable(true); // 4-byte LUTs for decryption
// Generate S-box substitution values.
fn generateSbox(invert: bool) [256]u8 {
@@ -472,14 +632,14 @@ fn generateTable(invert: bool) [4][256]u32 {
var table: [4][256]u32 = undefined;
for (generateSbox(invert), 0..) |value, index| {
-table[0][index] = mul(value, if (invert) 0xb else 0x3);
-table[0][index] |= math.shl(u32, mul(value, if (invert) 0xd else 0x1), 8);
-table[0][index] |= math.shl(u32, mul(value, if (invert) 0x9 else 0x1), 16);
-table[0][index] |= math.shl(u32, mul(value, if (invert) 0xe else 0x2), 24);
+table[0][index] = math.shl(u32, mul(value, if (invert) 0xb else 0x3), 24);
+table[0][index] |= math.shl(u32, mul(value, if (invert) 0xd else 0x1), 16);
+table[0][index] |= math.shl(u32, mul(value, if (invert) 0x9 else 0x1), 8);
+table[0][index] |= mul(value, if (invert) 0xe else 0x2);
-table[1][index] = math.rotr(u32, table[0][index], 8);
-table[2][index] = math.rotr(u32, table[0][index], 16);
-table[3][index] = math.rotr(u32, table[0][index], 24);
+table[1][index] = math.rotl(u32, table[0][index], 8);
+table[2][index] = math.rotl(u32, table[0][index], 16);
+table[3][index] = math.rotl(u32, table[0][index], 24);
}
return table;
@@ -506,3 +666,82 @@ fn mul(a: u8, b: u8) u8 {
return @truncate(u8, s);
}
+const cache_line_bytes = 64;
+inline fn sbox_lookup(sbox: *align(64) const [256]u8, idx0: u8, idx1: u8, idx2: u8, idx3: u8) [4]u8 {
+if (side_channels_mitigations == .none) {
+return [4]u8{
+sbox[idx0],
+sbox[idx1],
+sbox[idx2],
+sbox[idx3],
+};
+} else {
+const stride = switch (side_channels_mitigations) {
+.none => unreachable,
+.basic => sbox.len / 4,
+.medium => sbox.len / (sbox.len / cache_line_bytes) * 2,
+.full => sbox.len / (sbox.len / cache_line_bytes),
+};
+const of0 = idx0 % stride;
+const of1 = idx1 % stride;
+const of2 = idx2 % stride;
+const of3 = idx3 % stride;
+var t: [4][sbox.len / stride]u8 align(64) = undefined;
+var i: usize = 0;
+while (i < t[0].len) : (i += 1) {
+const tx = sbox[i * stride ..];
+t[0][i] = tx[of0];
+t[1][i] = tx[of1];
+t[2][i] = tx[of2];
+t[3][i] = tx[of3];
+}
+std.mem.doNotOptimizeAway(t);
+return [4]u8{
+t[0][idx0 / stride],
+t[1][idx1 / stride],
+t[2][idx2 / stride],
+t[3][idx3 / stride],
+};
+}
+}
+inline fn table_lookup(table: *align(64) const [4][256]u32, idx0: u8, idx1: u8, idx2: u8, idx3: u8) [4]u32 {
+if (side_channels_mitigations == .none) {
+return [4]u32{
+table[0][idx0],
+table[1][idx1],
+table[2][idx2],
+table[3][idx3],
+};
+} else {
+const table_bytes = @sizeOf(@TypeOf(table[0]));
+const stride = switch (side_channels_mitigations) {
+.none => unreachable,
+.basic => table[0].len / 4,
+.medium => table[0].len / (table_bytes / cache_line_bytes) * 2,
+.full => table[0].len / (table_bytes / cache_line_bytes),
+};
+const of0 = idx0 % stride;
+const of1 = idx1 % stride;
+const of2 = idx2 % stride;
+const of3 = idx3 % stride;
+var t: [4][table[0].len / stride]u32 align(64) = undefined;
+var i: usize = 0;
+while (i < t[0].len) : (i += 1) {
+const tx = table[0][i * stride ..];
+t[0][i] = tx[of0];
+t[1][i] = tx[of1];
+t[2][i] = tx[of2];
+t[3][i] = tx[of3];
+}
+std.mem.doNotOptimizeAway(t);
+return [4]u32{
+t[0][idx0 / stride],
+math.rotl(u32, t[1][idx1 / stride], 8),
+math.rotl(u32, t[2][idx2 / stride], 16),
+math.rotl(u32, t[3][idx3 / stride], 24),
+};
+}
+}