crypto.tls: support rsa_pss_rsae_sha256 and fixes

* fix eof logic
 * fix read logic
 * fix VecPut logic
 * add some debug prints to remove later
This commit is contained in:
Andrew Kelley 2022-12-29 17:56:46 -07:00
parent e4a9b19a14
commit 22e2aaa283
2 changed files with 239 additions and 37 deletions

View File

@ -474,19 +474,9 @@ fn verifyRsa(
pub_key: []const u8,
) !void {
if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch;
const pub_key_seq = try der.Element.parse(pub_key, 0);
if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType;
const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start);
if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
const exponent_elem = try der.Element.parse(pub_key, modulus_elem.slice.end);
if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
// Skip over meaningless zeroes in the modulus.
const modulus_raw = pub_key[modulus_elem.slice.start..modulus_elem.slice.end];
const modulus_offset = for (modulus_raw) |byte, i| {
if (byte != 0) break i;
} else modulus_raw.len;
const modulus = modulus_raw[modulus_offset..];
const exponent = pub_key[exponent_elem.slice.start..exponent_elem.slice.end];
const pk_components = try rsa.PublicKey.parseDer(pub_key);
const exponent = pk_components.exponent;
const modulus = pk_components.modulus;
if (exponent.len > modulus.len) return error.CertificatePublicKeyInvalid;
if (sig.len != modulus.len) return error.CertificateSignatureInvalidLength;
@ -688,10 +678,154 @@ test {
/// which is licensed under the Apache License Version 2.0, January 2004
/// http://www.apache.org/licenses/
/// The code has been modified.
const rsa = struct {
pub const rsa = struct {
const BigInt = std.math.big.int.Managed;
const PublicKey = struct {
pub const PSSSignature = struct {
pub fn fromBytes(comptime modulus_len: usize, msg: []const u8) [modulus_len]u8 {
var result = [1]u8{0} ** modulus_len;
std.mem.copy(u8, &result, msg);
return result;
}
pub fn verify(comptime modulus_len: usize, sig: [modulus_len]u8, msg: []const u8, public_key: PublicKey, comptime Hash: type, allocator: std.mem.Allocator) !void {
const mod_bits = try countBits(public_key.n.toConst(), allocator);
const em_dec = try encrypt(modulus_len, sig, public_key, allocator);
try EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash, allocator);
}
fn EMSA_PSS_VERIFY(msg: []const u8, em: []const u8, emBit: usize, sLen: usize, comptime Hash: type, allocator: std.mem.Allocator) !void {
// TODO
// 1. If the length of M is greater than the input limitation for
// the hash function (2^61 - 1 octets for SHA-1), output
// "inconsistent" and stop.
// emLen = \ceil(emBits/8)
const emLen = ((emBit - 1) / 8) + 1;
std.debug.assert(emLen == em.len);
// 2. Let mHash = Hash(M), an octet string of length hLen.
var mHash: [Hash.digest_length]u8 = undefined;
Hash.hash(msg, &mHash, .{});
// 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop.
if (emLen < Hash.digest_length + sLen + 2) {
return error.InvalidSignature;
}
// 4. If the rightmost octet of EM does not have hexadecimal value
// 0xbc, output "inconsistent" and stop.
if (em[em.len - 1] != 0xbc) {
return error.InvalidSignature;
}
// 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM,
// and let H be the next hLen octets.
const maskedDB = em[0..(emLen - Hash.digest_length - 1)];
const h = em[(emLen - Hash.digest_length - 1)..(emLen - 1)];
// 6. If the leftmost 8emLen - emBits bits of the leftmost octet in
// maskedDB are not all equal to zero, output "inconsistent" and
// stop.
const zero_bits = emLen * 8 - emBit;
var mask: u8 = maskedDB[0];
var i: usize = 0;
while (i < 8 - zero_bits) : (i += 1) {
mask = mask >> 1;
}
if (mask != 0) {
return error.InvalidSignature;
}
// 7. Let dbMask = MGF(H, emLen - hLen - 1).
const mgf_len = emLen - Hash.digest_length - 1;
var mgf_out = try allocator.alloc(u8, ((mgf_len - 1) / Hash.digest_length + 1) * Hash.digest_length);
defer allocator.free(mgf_out);
var dbMask = try MGF1(mgf_out, h, mgf_len, Hash, allocator);
// 8. Let DB = maskedDB \xor dbMask.
i = 0;
while (i < dbMask.len) : (i += 1) {
dbMask[i] = maskedDB[i] ^ dbMask[i];
}
// 9. Set the leftmost 8emLen - emBits bits of the leftmost octet
// in DB to zero.
i = 0;
mask = 0;
while (i < 8 - zero_bits) : (i += 1) {
mask = mask << 1;
mask += 1;
}
dbMask[0] = dbMask[0] & mask;
// 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not
// zero or if the octet at position emLen - hLen - sLen - 1 (the
// leftmost position is "position 1") does not have hexadecimal
// value 0x01, output "inconsistent" and stop.
if (dbMask[mgf_len - sLen - 2] != 0x00) {
return error.InvalidSignature;
}
if (dbMask[mgf_len - sLen - 1] != 0x01) {
return error.InvalidSignature;
}
// 11. Let salt be the last sLen octets of DB.
const salt = dbMask[(mgf_len - sLen)..];
// 12. Let
// M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
// M' is an octet string of length 8 + hLen + sLen with eight
// initial zero octets.
var m_p = try allocator.alloc(u8, 8 + Hash.digest_length + sLen);
defer allocator.free(m_p);
std.mem.copy(u8, m_p, &([_]u8{0} ** 8));
std.mem.copy(u8, m_p[8..], &mHash);
std.mem.copy(u8, m_p[(8 + Hash.digest_length)..], salt);
// 13. Let H' = Hash(M'), an octet string of length hLen.
var h_p: [Hash.digest_length]u8 = undefined;
Hash.hash(m_p, &h_p, .{});
// 14. If H = H', output "consistent". Otherwise, output
// "inconsistent".
if (!std.mem.eql(u8, h, &h_p)) {
return error.InvalidSignature;
}
}
fn MGF1(out: []u8, seed: []const u8, len: usize, comptime Hash: type, allocator: std.mem.Allocator) ![]u8 {
var counter: usize = 0;
var idx: usize = 0;
var c: [4]u8 = undefined;
var hash = try allocator.alloc(u8, seed.len + c.len);
defer allocator.free(hash);
std.mem.copy(u8, hash, seed);
var hashed: [Hash.digest_length]u8 = undefined;
while (idx < len) {
c[0] = @intCast(u8, (counter >> 24) & 0xFF);
c[1] = @intCast(u8, (counter >> 16) & 0xFF);
c[2] = @intCast(u8, (counter >> 8) & 0xFF);
c[3] = @intCast(u8, counter & 0xFF);
std.mem.copy(u8, hash[seed.len..], &c);
Hash.hash(hash, &hashed, .{});
std.mem.copy(u8, out[idx..], &hashed);
idx += hashed.len;
counter += 1;
}
return out[0..len];
}
};
pub const PublicKey = struct {
n: BigInt,
e: BigInt,
@ -714,6 +848,24 @@ const rsa = struct {
.e = _e,
};
}
pub fn parseDer(pub_key: []const u8) !struct { modulus: []const u8, exponent: []const u8 } {
const pub_key_seq = try der.Element.parse(pub_key, 0);
if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType;
const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start);
if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
const exponent_elem = try der.Element.parse(pub_key, modulus_elem.slice.end);
if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
// Skip over meaningless zeroes in the modulus.
const modulus_raw = pub_key[modulus_elem.slice.start..modulus_elem.slice.end];
const modulus_offset = for (modulus_raw) |byte, i| {
if (byte != 0) break i;
} else modulus_raw.len;
return .{
.modulus = modulus_raw[modulus_offset..],
.exponent = pub_key[exponent_elem.slice.start..exponent_elem.slice.end],
};
}
};
fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey, allocator: std.mem.Allocator) ![modulus_len]u8 {
@ -812,6 +964,20 @@ const rsa = struct {
try BigInt.divFloor(&q, rem, a, n);
}
fn countBits(a: std.math.big.int.Const, allocator: std.mem.Allocator) !usize {
var i: usize = 0;
var a_copy = try BigInt.init(allocator);
defer a_copy.deinit();
try a_copy.copy(a);
while (!a_copy.eqZero()) {
try a_copy.shiftRight(&a_copy, 1);
i += 1;
}
return i;
}
// TODO: flush the toilet
const poop = std.heap.page_allocator;
pub const poop = std.heap.page_allocator;
};

View File

@ -536,7 +536,24 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
try sig.verify(verify_bytes, key);
},
.rsa_pss_rsae_sha256 => {
@panic("TODO signature scheme: rsa_pss_rsae_sha256");
if (main_cert_pub_key_algo != .rsaEncryption)
return error.TlsBadSignatureScheme;
const Hash = crypto.hash.sha2.Sha256;
const rsa = Certificate.rsa;
const components = try rsa.PublicKey.parseDer(main_cert_pub_key);
const exponent = components.exponent;
const modulus = components.modulus;
switch (modulus.len) {
inline 128, 256, 512 => |modulus_len| {
const key = try rsa.PublicKey.fromBytes(exponent, modulus, rsa.poop);
const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig);
try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, rsa.poop);
},
else => {
return error.TlsBadRsaSignatureBitCount;
},
}
},
else => {
//std.debug.print("signature scheme: {any}\n", .{
@ -737,7 +754,7 @@ pub fn writeAll(c: *Client, stream: net.Stream, bytes: []const u8) !void {
}
pub fn eof(c: Client) bool {
return c.received_close_notify and c.partial_ciphertext_end == 0;
return c.received_close_notify and c.partial_ciphertext_idx >= c.partial_ciphertext_end;
}
/// Returns the number of bytes read, calling the underlying read function the
@ -822,6 +839,10 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
c.partial_cleartext_idx = 0;
c.partial_ciphertext_idx = 0;
c.partial_ciphertext_end = 0;
} else {
std.debug.print("finished giving partial cleartext. {d} bytes ciphertext remain\n", .{
c.partial_ciphertext_end - c.partial_ciphertext_idx,
});
}
}
@ -866,8 +887,9 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
// There might be more bytes inside `in_stack_buffer` that need to be processed,
// but at least frag0 will have one complete ciphertext record.
const frag0 = c.partially_read_buffer[0..@min(c.partially_read_buffer.len, actual_read_len)];
var frag1 = in_stack_buffer[0 .. actual_read_len - frag0.len];
const frag0_end = @min(c.partially_read_buffer.len, c.partial_ciphertext_end + actual_read_len);
const frag0 = c.partially_read_buffer[c.partial_ciphertext_idx..frag0_end];
var frag1 = in_stack_buffer[0..actual_read_len -| first_iov.len];
// We need to decipher frag0 and frag1 but there may be a ciphertext record
// straddling the boundary. We can handle this with two memcpy() calls to
// assemble the straddling record in between handling the two sides.
@ -900,12 +922,14 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
const record_len = (record_len_byte_0 << 8) | record_len_byte_1;
if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
const second_len = record_len + tls.ciphertext_record_header_len - first.len;
const full_record_len = record_len + tls.ciphertext_record_header_len;
const second_len = full_record_len - first.len;
if (frag1.len < second_len)
return finishRead2(c, first, frag1, vp.total);
mem.copy(u8, frag[0..in], first);
mem.copy(u8, frag[first.len..], frag1[0..second_len]);
frag = frag[0..full_record_len];
frag1 = frag1[second_len..];
in = 0;
continue;
@ -914,23 +938,35 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
in += 1;
const legacy_version = mem.readIntBig(u16, frag[in..][0..2]);
in += 2;
_ = legacy_version;
//_ = legacy_version;
const record_len = mem.readIntBig(u16, frag[in..][0..2]);
std.debug.print("ct={any} legacy_version={x} record_len={d}\n", .{
ct, legacy_version, record_len,
});
if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
in += 2;
const end = in + record_len;
if (end > frag.len) {
// We need the record header on the next iteration of the loop.
in -= tls.ciphertext_record_header_len;
if (frag.ptr == frag1.ptr)
return finishRead(c, frag, in, vp.total);
// A record straddles the two fragments. Copy into the now-empty first fragment.
const first = frag[in..];
const second_len = record_len + tls.ciphertext_record_header_len - first.len;
if (frag1.len < second_len)
const full_record_len = record_len + tls.ciphertext_record_header_len;
const second_len = full_record_len - first.len;
if (frag1.len < second_len) {
std.debug.print("end > frag.len finishRead2 end={d} frag.len={d}\n", .{
end, frag.len,
});
return finishRead2(c, first, frag1, vp.total);
}
mem.copy(u8, frag[0..in], first);
mem.copy(u8, frag[first.len..], frag1[0..second_len]);
frag = frag[0..full_record_len];
frag1 = frag1[second_len..];
in = 0;
continue;
@ -991,9 +1027,11 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
const handshake = cleartext[ct_i..next_handshake_i];
switch (handshake_type) {
.new_session_ticket => {
std.debug.print("new_session_ticket\n", .{});
// This client implementation ignores new session tickets.
},
.key_update => {
std.debug.print("key_update\n", .{});
switch (c.application_cipher) {
inline else => |*p| {
const P = @TypeOf(p.*);
@ -1042,10 +1080,13 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
const dest = c.partially_read_buffer[c.partial_ciphertext_idx..];
mem.copy(u8, dest, msg);
c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), c.partial_ciphertext_idx + msg.len);
std.debug.print("application_data {d} bytes to partial buffer\n", .{msg.len});
} else {
const amt = vp.put(msg);
std.debug.print("application_data {d} bytes to read buffer\n", .{msg.len});
if (amt < msg.len) {
const rest = msg[amt..];
std.debug.print(" {d} bytes to partial buffer\n", .{rest.len});
c.partial_cleartext_idx = 0;
c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), rest.len);
mem.copy(u8, &c.partially_read_buffer, rest);
@ -1055,6 +1096,7 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
// Output buffer was used directly which means no
// memory copying needs to occur, and we can move
// on to the next ciphertext record.
std.debug.print("application_data {d} bytes directly to read buffer\n", .{cleartext.len - 1});
vp.next(cleartext.len - 1);
}
},
@ -1166,10 +1208,6 @@ const VecPut = struct {
const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)];
mem.copy(u8, dest, src);
bytes_i += src.len;
if (bytes_i >= bytes.len) {
vp.total += bytes_i;
return bytes_i;
}
vp.off += src.len;
if (vp.off >= v.iov_len) {
vp.off = 0;
@ -1179,6 +1217,10 @@ const VecPut = struct {
return bytes_i;
}
}
if (bytes_i >= bytes.len) {
vp.total += bytes_i;
return bytes_i;
}
}
}
@ -1201,17 +1243,11 @@ const VecPut = struct {
}
fn freeSize(vp: VecPut) usize {
if (vp.idx >= vp.iovecs.len) return 0;
var total: usize = 0;
total += vp.iovecs[vp.idx].iov_len - vp.off;
if (vp.idx + 1 >= vp.iovecs.len)
return total;
for (vp.iovecs[vp.idx + 1 ..]) |v| {
total += v.iov_len;
}
if (vp.idx + 1 >= vp.iovecs.len) return total;
for (vp.iovecs[vp.idx + 1 ..]) |v| total += v.iov_len;
return total;
}
};