simd: scalar base64 ignores invalid padding
parent
530633b9e9
commit
8477b6a21a
|
|
@ -1,10 +1,9 @@
|
|||
const std = @import("std");
|
||||
const options = @import("build_options");
|
||||
const assert = std.debug.assert;
|
||||
const log = std.log.scoped(.simd_base64);
|
||||
const scalar_decoder = @import("base64_scalar.zig").scalar_decoder;
|
||||
|
||||
// Used for the non-SIMD implementation
|
||||
const Base64Decoder = std.base64.standard_no_pad.Decoder;
|
||||
const log = std.log.scoped(.simd_base64);
|
||||
|
||||
pub fn maxLen(input: []const u8) usize {
|
||||
if (comptime options.simd) return ghostty_simd_base64_max_length(
|
||||
|
|
@ -16,7 +15,7 @@ pub fn maxLen(input: []const u8) usize {
|
|||
}
|
||||
|
||||
fn maxLenScalar(input: []const u8) usize {
|
||||
return Base64Decoder.calcSizeForSlice(scalarInput(input)) catch |err| {
|
||||
return scalar_decoder.calcSizeForSlice(scalarInput(input)) catch |err| {
|
||||
log.warn("failed to calculate base64 size for payload: {}", .{err});
|
||||
return 0;
|
||||
};
|
||||
|
|
@ -44,7 +43,7 @@ fn decodeScalar(
|
|||
const size = maxLenScalar(input);
|
||||
if (size == 0) return "";
|
||||
assert(output.len >= size);
|
||||
Base64Decoder.decode(
|
||||
scalar_decoder.decode(
|
||||
output,
|
||||
scalarInput(input),
|
||||
) catch return error.Base64Invalid;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,147 @@
|
|||
const std = @import("std");
|
||||
const assert = std.debug.assert;
|
||||
|
||||
pub const scalar_decoder: Base64Decoder = .init(
|
||||
std.base64.standard_alphabet_chars,
|
||||
null,
|
||||
);
|
||||
|
||||
/// Copied from Zig 0.14.1 stdlib and commented out the invalid padding
|
||||
/// scenarios, because Kitty Graphics requires a decoder that doesn't care
|
||||
/// about invalid padding scenarios.
|
||||
const Base64Decoder = struct {
|
||||
const invalid_char: u8 = 0xff;
|
||||
const invalid_char_tst: u32 = 0xff000000;
|
||||
|
||||
const Error = error{
|
||||
InvalidCharacter,
|
||||
InvalidPadding,
|
||||
NoSpaceLeft,
|
||||
};
|
||||
|
||||
/// e.g. 'A' => 0.
|
||||
/// `invalid_char` for any value not in the 64 alphabet chars.
|
||||
char_to_index: [256]u8,
|
||||
fast_char_to_index: [4][256]u32,
|
||||
pad_char: ?u8,
|
||||
|
||||
pub fn init(alphabet_chars: [64]u8, pad_char: ?u8) Base64Decoder {
|
||||
var result = Base64Decoder{
|
||||
.char_to_index = [_]u8{invalid_char} ** 256,
|
||||
.fast_char_to_index = .{[_]u32{invalid_char_tst} ** 256} ** 4,
|
||||
.pad_char = pad_char,
|
||||
};
|
||||
|
||||
var char_in_alphabet = [_]bool{false} ** 256;
|
||||
for (alphabet_chars, 0..) |c, i| {
|
||||
assert(!char_in_alphabet[c]);
|
||||
assert(pad_char == null or c != pad_char.?);
|
||||
|
||||
const ci = @as(u32, @intCast(i));
|
||||
result.fast_char_to_index[0][c] = ci << 2;
|
||||
result.fast_char_to_index[1][c] = (ci >> 4) | ((ci & 0x0f) << 12);
|
||||
result.fast_char_to_index[2][c] = ((ci & 0x3) << 22) | ((ci & 0x3c) << 6);
|
||||
result.fast_char_to_index[3][c] = ci << 16;
|
||||
|
||||
result.char_to_index[c] = @as(u8, @intCast(i));
|
||||
char_in_alphabet[c] = true;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Return the maximum possible decoded size for a given input length - The actual length may be less if the input includes padding.
|
||||
/// `InvalidPadding` is returned if the input length is not valid.
|
||||
pub fn calcSizeUpperBound(decoder: *const Base64Decoder, source_len: usize) Error!usize {
|
||||
var result = source_len / 4 * 3;
|
||||
const leftover = source_len % 4;
|
||||
if (decoder.pad_char != null) {
|
||||
if (leftover % 4 != 0) return error.InvalidPadding;
|
||||
} else {
|
||||
if (leftover % 4 == 1) return error.InvalidPadding;
|
||||
result += leftover * 3 / 4;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Return the exact decoded size for a slice.
|
||||
/// `InvalidPadding` is returned if the input length is not valid.
|
||||
pub fn calcSizeForSlice(decoder: *const Base64Decoder, source: []const u8) Error!usize {
|
||||
const source_len = source.len;
|
||||
var result = try decoder.calcSizeUpperBound(source_len);
|
||||
if (decoder.pad_char) |pad_char| {
|
||||
if (source_len >= 1 and source[source_len - 1] == pad_char) result -= 1;
|
||||
if (source_len >= 2 and source[source_len - 2] == pad_char) result -= 1;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// dest.len must be what you get from ::calcSize.
|
||||
/// Invalid characters result in `error.InvalidCharacter`.
|
||||
/// Invalid padding results in `error.InvalidPadding`.
|
||||
pub fn decode(decoder: *const Base64Decoder, dest: []u8, source: []const u8) Error!void {
|
||||
if (decoder.pad_char != null and source.len % 4 != 0) return error.InvalidPadding;
|
||||
var dest_idx: usize = 0;
|
||||
var fast_src_idx: usize = 0;
|
||||
var acc: u12 = 0;
|
||||
var acc_len: u4 = 0;
|
||||
var leftover_idx: ?usize = null;
|
||||
while (fast_src_idx + 16 < source.len and dest_idx + 15 < dest.len) : ({
|
||||
fast_src_idx += 16;
|
||||
dest_idx += 12;
|
||||
}) {
|
||||
var bits: u128 = 0;
|
||||
inline for (0..4) |i| {
|
||||
var new_bits: u128 = decoder.fast_char_to_index[0][source[fast_src_idx + i * 4]];
|
||||
new_bits |= decoder.fast_char_to_index[1][source[fast_src_idx + 1 + i * 4]];
|
||||
new_bits |= decoder.fast_char_to_index[2][source[fast_src_idx + 2 + i * 4]];
|
||||
new_bits |= decoder.fast_char_to_index[3][source[fast_src_idx + 3 + i * 4]];
|
||||
if ((new_bits & invalid_char_tst) != 0) return error.InvalidCharacter;
|
||||
bits |= (new_bits << (24 * i));
|
||||
}
|
||||
std.mem.writeInt(u128, dest[dest_idx..][0..16], bits, .little);
|
||||
}
|
||||
while (fast_src_idx + 4 < source.len and dest_idx + 3 < dest.len) : ({
|
||||
fast_src_idx += 4;
|
||||
dest_idx += 3;
|
||||
}) {
|
||||
var bits = decoder.fast_char_to_index[0][source[fast_src_idx]];
|
||||
bits |= decoder.fast_char_to_index[1][source[fast_src_idx + 1]];
|
||||
bits |= decoder.fast_char_to_index[2][source[fast_src_idx + 2]];
|
||||
bits |= decoder.fast_char_to_index[3][source[fast_src_idx + 3]];
|
||||
if ((bits & invalid_char_tst) != 0) return error.InvalidCharacter;
|
||||
std.mem.writeInt(u32, dest[dest_idx..][0..4], bits, .little);
|
||||
}
|
||||
const remaining = source[fast_src_idx..];
|
||||
for (remaining, fast_src_idx..) |c, src_idx| {
|
||||
const d = decoder.char_to_index[c];
|
||||
if (d == invalid_char) {
|
||||
if (decoder.pad_char == null or c != decoder.pad_char.?) return error.InvalidCharacter;
|
||||
leftover_idx = src_idx;
|
||||
break;
|
||||
}
|
||||
acc = (acc << 6) + d;
|
||||
acc_len += 6;
|
||||
if (acc_len >= 8) {
|
||||
acc_len -= 8;
|
||||
dest[dest_idx] = @as(u8, @truncate(acc >> acc_len));
|
||||
dest_idx += 1;
|
||||
}
|
||||
}
|
||||
// if (acc_len > 4 or (acc & (@as(u12, 1) << acc_len) - 1) != 0) {
|
||||
// return error.InvalidPadding;
|
||||
// }
|
||||
if (leftover_idx == null) return;
|
||||
const leftover = source[leftover_idx.?..];
|
||||
if (decoder.pad_char) |pad_char| {
|
||||
const padding_len = acc_len / 2;
|
||||
var padding_chars: usize = 0;
|
||||
for (leftover) |c| {
|
||||
if (c != pad_char) {
|
||||
return if (c == Base64Decoder.invalid_char) error.InvalidCharacter else error.InvalidPadding;
|
||||
}
|
||||
padding_chars += 1;
|
||||
}
|
||||
if (padding_chars != padding_len) return error.InvalidPadding;
|
||||
}
|
||||
}
|
||||
};
|
||||
Loading…
Reference in New Issue