std.crypto: SHA-256 Properly gate comptime conditional

This feature detection must be done at comptime so that we avoid
generating invalid ASM for the target.
This commit is contained in:
Cody Tapscott 2022-10-24 00:38:10 -07:00
parent 10edb6d352
commit ee241c47ee

View File

@ -192,85 +192,89 @@ fn Sha2x32(comptime params: Sha2Params32) type {
s[i] |= @as(u32, b[i * 4 + 3]) << 0;
}
if (builtin.cpu.arch == .aarch64 and builtin.cpu.features.isEnabled(@enumToInt(std.Target.aarch64.Feature.sha2))) {
var x: v4u32 = d.s[0..4].*;
var y: v4u32 = d.s[4..8].*;
const s_v = @ptrCast(*[16]v4u32, &s);
switch (builtin.cpu.arch) {
.aarch64 => if (comptime builtin.cpu.features.isEnabled(@enumToInt(std.Target.aarch64.Feature.sha2))) {
var x: v4u32 = d.s[0..4].*;
var y: v4u32 = d.s[4..8].*;
const s_v = @ptrCast(*[16]v4u32, &s);
comptime var k: u8 = 0;
inline while (k < 16) : (k += 1) {
if (k > 3) {
s_v[k] = asm (
\\sha256su0.4s %[w0_3], %[w4_7]
\\sha256su1.4s %[w0_3], %[w8_11], %[w12_15]
: [w0_3] "=w" (-> v4u32),
: [_] "0" (s_v[k - 4]),
[w4_7] "w" (s_v[k - 3]),
[w8_11] "w" (s_v[k - 2]),
[w12_15] "w" (s_v[k - 1]),
comptime var k: u8 = 0;
inline while (k < 16) : (k += 1) {
if (k > 3) {
s_v[k] = asm (
\\sha256su0.4s %[w0_3], %[w4_7]
\\sha256su1.4s %[w0_3], %[w8_11], %[w12_15]
: [w0_3] "=w" (-> v4u32),
: [_] "0" (s_v[k - 4]),
[w4_7] "w" (s_v[k - 3]),
[w8_11] "w" (s_v[k - 2]),
[w12_15] "w" (s_v[k - 1]),
);
}
const w: v4u32 = s_v[k] +% @as(v4u32, W[4 * k ..][0..4].*);
asm volatile (
\\mov.4s v0, %[x]
\\sha256h.4s %[x], %[y], %[w]
\\sha256h2.4s %[y], v0, %[w]
: [x] "=w" (x),
[y] "=w" (y),
: [_] "0" (x),
[_] "1" (y),
[w] "w" (w),
: "v0"
);
}
const w: v4u32 = s_v[k] +% @as(v4u32, W[4 * k ..][0..4].*);
asm volatile (
\\mov.4s v0, %[x]
\\sha256h.4s %[x], %[y], %[w]
\\sha256h2.4s %[y], v0, %[w]
: [x] "=w" (x),
[y] "=w" (y),
: [_] "0" (x),
[_] "1" (y),
[w] "w" (w),
: "v0"
);
}
d.s[0..4].* = x +% @as(v4u32, d.s[0..4].*);
d.s[4..8].* = y +% @as(v4u32, d.s[4..8].*);
return;
},
.x86_64 => if (comptime builtin.cpu.features.isEnabled(@enumToInt(std.Target.x86.Feature.sha))) {
var x: v4u32 = [_]u32{ d.s[5], d.s[4], d.s[1], d.s[0] };
var y: v4u32 = [_]u32{ d.s[7], d.s[6], d.s[3], d.s[2] };
const s_v = @ptrCast(*[16]v4u32, &s);
d.s[0..4].* = x +% @as(v4u32, d.s[0..4].*);
d.s[4..8].* = y +% @as(v4u32, d.s[4..8].*);
return;
} else if (builtin.cpu.arch == .x86_64 and builtin.cpu.features.isEnabled(@enumToInt(std.Target.x86.Feature.sha))) {
var x: v4u32 = [_]u32{ d.s[5], d.s[4], d.s[1], d.s[0] };
var y: v4u32 = [_]u32{ d.s[7], d.s[6], d.s[3], d.s[2] };
const s_v = @ptrCast(*[16]v4u32, &s);
comptime var k: u8 = 0;
inline while (k < 16) : (k += 1) {
if (k < 12) {
const r = asm ("sha256msg1 %[w4_7], %[w0_3]"
: [w0_3] "=x" (-> v4u32),
: [_] "0" (s_v[k]),
[w4_7] "x" (s_v[k + 1]),
);
const t = @shuffle(u32, s_v[k + 2], s_v[k + 3], [_]i32{ 1, 2, 3, -1 });
s_v[k + 4] = asm ("sha256msg2 %[w12_15], %[t]"
: [t] "=x" (-> v4u32),
: [_] "0" (r +% t),
[w12_15] "x" (s_v[k + 3]),
);
}
comptime var k: u8 = 0;
inline while (k < 16) : (k += 1) {
if (k < 12) {
const r = asm ("sha256msg1 %[w4_7], %[w0_3]"
: [w0_3] "=x" (-> v4u32),
: [_] "0" (s_v[k]),
[w4_7] "x" (s_v[k + 1]),
);
const t = @shuffle(u32, s_v[k + 2], s_v[k + 3], [_]i32{ 1, 2, 3, -1 });
s_v[k + 4] = asm ("sha256msg2 %[w12_15], %[t]"
: [t] "=x" (-> v4u32),
: [_] "0" (r +% t),
[w12_15] "x" (s_v[k + 3]),
const w: v4u32 = s_v[k] +% @as(v4u32, W[4 * k ..][0..4].*);
asm volatile (
\\sha256rnds2 %[x], %[y]
\\pshufd $0xe, %%xmm0, %%xmm0
\\sha256rnds2 %[y], %[x]
: [y] "=x" (y),
[x] "=x" (x),
: [_] "0" (y),
[_] "1" (x),
[_] "{xmm0}" (w),
);
}
const w: v4u32 = s_v[k] +% @as(v4u32, W[4 * k ..][0..4].*);
asm volatile (
\\sha256rnds2 %[x], %[y]
\\pshufd $0xe, %%xmm0, %%xmm0
\\sha256rnds2 %[y], %[x]
: [y] "=x" (y),
[x] "=x" (x),
: [_] "0" (y),
[_] "1" (x),
[_] "{xmm0}" (w),
);
}
d.s[0] +%= x[3];
d.s[1] +%= x[2];
d.s[4] +%= x[1];
d.s[5] +%= x[0];
d.s[2] +%= y[3];
d.s[3] +%= y[2];
d.s[6] +%= y[1];
d.s[7] +%= y[0];
return;
d.s[0] +%= x[3];
d.s[1] +%= x[2];
d.s[4] +%= x[1];
d.s[5] +%= x[0];
d.s[2] +%= y[3];
d.s[3] +%= y[2];
d.s[6] +%= y[1];
d.s[7] +%= y[0];
return;
},
else => {},
}
while (i < 64) : (i += 1) {