diff --git a/lib/std/base64.zig b/lib/std/base64.zig index 35a82dde01..a2f766dc26 100644 --- a/lib/std/base64.zig +++ b/lib/std/base64.zig @@ -102,14 +102,28 @@ pub const Base64Encoder = struct { var idx: usize = 0; var out_idx: usize = 0; - while (idx + 2 < source.len) : (idx += 3) { + while (idx + 15 < source.len) : (idx += 12) { + const bits = std.mem.readIntBig(u128, source[idx..][0..16]); + inline for (0..16) |i| { + dest[out_idx + i] = encoder.alphabet_chars[@truncate((bits >> (122 - i * 6)) & 0x3f)]; + } + out_idx += 16; + } + while (idx + 3 < source.len) : (idx += 3) { + const bits = std.mem.readIntBig(u32, source[idx..][0..4]); + dest[out_idx] = encoder.alphabet_chars[(bits >> 26) & 0x3f]; + dest[out_idx + 1] = encoder.alphabet_chars[(bits >> 20) & 0x3f]; + dest[out_idx + 2] = encoder.alphabet_chars[(bits >> 14) & 0x3f]; + dest[out_idx + 3] = encoder.alphabet_chars[(bits >> 8) & 0x3f]; + out_idx += 4; + } + if (idx + 2 < source.len) { dest[out_idx] = encoder.alphabet_chars[source[idx] >> 2]; dest[out_idx + 1] = encoder.alphabet_chars[((source[idx] & 0x3) << 4) | (source[idx + 1] >> 4)]; dest[out_idx + 2] = encoder.alphabet_chars[(source[idx + 1] & 0xf) << 2 | (source[idx + 2] >> 6)]; dest[out_idx + 3] = encoder.alphabet_chars[source[idx + 2] & 0x3f]; out_idx += 4; - } - if (idx + 1 < source.len) { + } else if (idx + 1 < source.len) { dest[out_idx] = encoder.alphabet_chars[source[idx] >> 2]; dest[out_idx + 1] = encoder.alphabet_chars[((source[idx] & 0x3) << 4) | (source[idx + 1] >> 4)]; dest[out_idx + 2] = encoder.alphabet_chars[(source[idx + 1] & 0xf) << 2]; @@ -130,15 +144,18 @@ pub const Base64Encoder = struct { pub const Base64Decoder = struct { const invalid_char: u8 = 0xff; + const invalid_char_tst: u32 = 0xff000000; /// e.g. 'A' => 0. /// `invalid_char` for any value not in the 64 alphabet chars. char_to_index: [256]u8, + fast_char_to_index: [4][256]u32, pad_char: ?u8, pub fn init(alphabet_chars: [64]u8, pad_char: ?u8) Base64Decoder { var result = Base64Decoder{ .char_to_index = [_]u8{invalid_char} ** 256, + .fast_char_to_index = .{[_]u32{invalid_char_tst} ** 256} ** 4, .pad_char = pad_char, }; @@ -147,6 +164,12 @@ pub const Base64Decoder = struct { assert(!char_in_alphabet[c]); assert(pad_char == null or c != pad_char.?); + const ci = @as(u32, @intCast(i)); + result.fast_char_to_index[0][c] = ci << 2; + result.fast_char_to_index[1][c] = (ci >> 4) | ((ci & 0x0f) << 12); + result.fast_char_to_index[2][c] = ((ci & 0x3) << 22) | ((ci & 0x3c) << 6); + result.fast_char_to_index[3][c] = ci << 16; + result.char_to_index[c] = @as(u8, @intCast(i)); char_in_alphabet[c] = true; } @@ -184,11 +207,39 @@ pub const Base64Decoder = struct { /// invalid padding results in error.InvalidPadding. pub fn decode(decoder: *const Base64Decoder, dest: []u8, source: []const u8) Error!void { if (decoder.pad_char != null and source.len % 4 != 0) return error.InvalidPadding; + var dest_idx: usize = 0; + var fast_src_idx: usize = 0; var acc: u12 = 0; var acc_len: u4 = 0; - var dest_idx: usize = 0; var leftover_idx: ?usize = null; - for (source, 0..) |c, src_idx| { + while (fast_src_idx + 16 < source.len and dest_idx + 15 < dest.len) : ({ + fast_src_idx += 16; + dest_idx += 12; + }) { + var bits: u128 = 0; + inline for (0..4) |i| { + var new_bits: u128 = decoder.fast_char_to_index[0][source[fast_src_idx + i * 4]]; + new_bits |= decoder.fast_char_to_index[1][source[fast_src_idx + 1 + i * 4]]; + new_bits |= decoder.fast_char_to_index[2][source[fast_src_idx + 2 + i * 4]]; + new_bits |= decoder.fast_char_to_index[3][source[fast_src_idx + 3 + i * 4]]; + if ((new_bits & invalid_char_tst) != 0) return error.InvalidCharacter; + bits |= (new_bits << (24 * i)); + } + std.mem.writeIntLittle(u128, dest[dest_idx..][0..16], bits); + } + while (fast_src_idx + 4 < source.len and dest_idx + 3 < dest.len) : ({ + fast_src_idx += 4; + dest_idx += 3; + }) { + var bits = decoder.fast_char_to_index[0][source[fast_src_idx]]; + bits |= decoder.fast_char_to_index[1][source[fast_src_idx + 1]]; + bits |= decoder.fast_char_to_index[2][source[fast_src_idx + 2]]; + bits |= decoder.fast_char_to_index[3][source[fast_src_idx + 3]]; + if ((bits & invalid_char_tst) != 0) return error.InvalidCharacter; + std.mem.writeIntLittle(u32, dest[dest_idx..][0..4], bits); + } + var remaining = source[fast_src_idx..]; + for (remaining, fast_src_idx..) |c, src_idx| { const d = decoder.char_to_index[c]; if (d == invalid_char) { if (decoder.pad_char == null or c != decoder.pad_char.?) return error.InvalidCharacter; @@ -338,6 +389,10 @@ fn testBase64() !void { try testAllApis(codecs, "foob", "Zm9vYg=="); try testAllApis(codecs, "fooba", "Zm9vYmE="); try testAllApis(codecs, "foobar", "Zm9vYmFy"); + try testAllApis(codecs, "foobarfoobarfoo", "Zm9vYmFyZm9vYmFyZm9v"); + try testAllApis(codecs, "foobarfoobarfoob", "Zm9vYmFyZm9vYmFyZm9vYg=="); + try testAllApis(codecs, "foobarfoobarfooba", "Zm9vYmFyZm9vYmFyZm9vYmE="); + try testAllApis(codecs, "foobarfoobarfoobar", "Zm9vYmFyZm9vYmFyZm9vYmFy"); try testDecodeIgnoreSpace(codecs, "", " "); try testDecodeIgnoreSpace(codecs, "f", "Z g= ="); @@ -357,11 +412,23 @@ fn testBase64() !void { try testError(codecs, "A/==", error.InvalidPadding); try testError(codecs, "A===", error.InvalidPadding); try testError(codecs, "====", error.InvalidPadding); + try testError(codecs, "Zm9vYmFyZm9vYmFyA..A", error.InvalidCharacter); + try testError(codecs, "Zm9vYmFyZm9vYmFyAA=A", error.InvalidPadding); + try testError(codecs, "Zm9vYmFyZm9vYmFyAA/=", error.InvalidPadding); + try testError(codecs, "Zm9vYmFyZm9vYmFyA/==", error.InvalidPadding); + try testError(codecs, "Zm9vYmFyZm9vYmFyA===", error.InvalidPadding); + try testError(codecs, "A..AZm9vYmFyZm9vYmFy", error.InvalidCharacter); + try testError(codecs, "Zm9vYmFyZm9vAA=A", error.InvalidPadding); + try testError(codecs, "Zm9vYmFyZm9vAA/=", error.InvalidPadding); + try testError(codecs, "Zm9vYmFyZm9vA/==", error.InvalidPadding); + try testError(codecs, "Zm9vYmFyZm9vA===", error.InvalidPadding); try testNoSpaceLeftError(codecs, "AA=="); try testNoSpaceLeftError(codecs, "AAA="); try testNoSpaceLeftError(codecs, "AAAA"); try testNoSpaceLeftError(codecs, "AAAAAA=="); + + try testFourBytesDestNoSpaceLeftError(codecs, "AAAAAAAAAAAAAAAA"); } fn testBase64UrlSafeNoPad() !void { @@ -374,6 +441,7 @@ fn testBase64UrlSafeNoPad() !void { try testAllApis(codecs, "foob", "Zm9vYg"); try testAllApis(codecs, "fooba", "Zm9vYmE"); try testAllApis(codecs, "foobar", "Zm9vYmFy"); + try testAllApis(codecs, "foobarfoobarfoobar", "Zm9vYmFyZm9vYmFyZm9vYmFy"); try testDecodeIgnoreSpace(codecs, "", " "); try testDecodeIgnoreSpace(codecs, "f", "Z g "); @@ -392,11 +460,15 @@ fn testBase64UrlSafeNoPad() !void { try testError(codecs, "A/==", error.InvalidCharacter); try testError(codecs, "A===", error.InvalidCharacter); try testError(codecs, "====", error.InvalidCharacter); + try testError(codecs, "Zm9vYmFyZm9vYmFyA..A", error.InvalidCharacter); + try testError(codecs, "A..AZm9vYmFyZm9vYmFy", error.InvalidCharacter); try testNoSpaceLeftError(codecs, "AA"); try testNoSpaceLeftError(codecs, "AAA"); try testNoSpaceLeftError(codecs, "AAAA"); try testNoSpaceLeftError(codecs, "AAAAAA"); + + try testFourBytesDestNoSpaceLeftError(codecs, "AAAAAAAAAAAAAAAA"); } fn testAllApis(codecs: Codecs, expected_decoded: []const u8, expected_encoded: []const u8) !void { @@ -457,3 +529,12 @@ fn testNoSpaceLeftError(codecs: Codecs, encoded: []const u8) !void { return error.ExpectedError; } else |err| if (err != error.NoSpaceLeft) return err; } + +fn testFourBytesDestNoSpaceLeftError(codecs: Codecs, encoded: []const u8) !void { + const decoder_ignore_space = codecs.decoderWithIgnore(" "); + var buffer: [0x100]u8 = undefined; + var decoded = buffer[0..4]; + if (decoder_ignore_space.decode(decoded, encoded)) |_| { + return error.ExpectedError; + } else |err| if (err != error.NoSpaceLeft) return err; +}