diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml
index d2841313eb..d174402f2a 100644
--- a/tokenizers/Cargo.toml
+++ b/tokenizers/Cargo.toml
@@ -66,6 +66,10 @@ harness = false
 name = "ci_benchmark"
 harness = false
 
+[[bench]]
+name = "ascii_lower_benchmark"
+harness = false
+
 [dependencies]
 rand = "0.9"
 onig = { version = "6.5.1", default-features = false, optional = true }
diff --git a/tokenizers/benches/ascii_lower_benchmark.rs b/tokenizers/benches/ascii_lower_benchmark.rs
new file mode 100644
index 0000000000..1b3a73b825
--- /dev/null
+++ b/tokenizers/benches/ascii_lower_benchmark.rs
@@ -0,0 +1,77 @@
+//! Microbenchmark for the ASCII lowercase fast path used by the `Lowercase`
+//! normalizer. Compares the SIMD-dispatched `utils::simd::ascii_lower` against
+//! the scalar reference and against the previous Unicode-aware path
+//! (`char::to_lowercase` per char) on representative buffer sizes.
+
+#[macro_use]
+extern crate criterion;
+
+use criterion::{Criterion, Throughput};
+use std::hint::black_box;
+use tokenizers::utils::simd::ascii_lower;
+
+fn make_buffer(len: usize) -> Vec<u8> {
+    // Cycle through printable ASCII (mix of upper, lower, digits, punctuation)
+    // so the upper-case branch fires on roughly 26/95 of bytes.
+    (0..len).map(|i| 0x20u8 + (i as u8 % 0x5F)).collect()
+}
+
+fn scalar_lower(buf: &mut [u8]) {
+    for b in buf {
+        if b.is_ascii_uppercase() {
+            *b |= 0x20;
+        }
+    }
+}
+
+fn unicode_lower(buf: &str) -> String {
+    // Mirrors what `NormalizedString::lowercase` did before the fast path: per
+    // `char` UTF-8 decode + Unicode case folding, ignoring alignment bookkeeping.
+    let mut out = String::with_capacity(buf.len());
+    for c in buf.chars() {
+        for lc in c.to_lowercase() {
+            out.push(lc);
+        }
+    }
+    out
+}
+
+pub fn bench_ascii_lower(c: &mut Criterion) {
+    for &len in &[64usize, 1024, 16 * 1024, 256 * 1024] {
+        let mut group = c.benchmark_group(format!("ascii_lower/{len}B"));
+        group.throughput(Throughput::Bytes(len as u64));
+
+        let original = make_buffer(len);
+
+        group.bench_function("simd", |b| {
+            let mut buf = original.clone();
+            b.iter(|| {
+                ascii_lower(black_box(&mut buf));
+            });
+        });
+
+        group.bench_function("scalar", |b| {
+            let mut buf = original.clone();
+            b.iter(|| {
+                scalar_lower(black_box(&mut buf));
+            });
+        });
+
+        // Stand-in for the pre-SIMD code path.
+        let s = String::from_utf8(original.clone()).unwrap();
+        group.bench_function("unicode_chars", |b| {
+            b.iter(|| {
+                black_box(unicode_lower(black_box(&s)));
+            });
+        });
+
+        group.finish();
+    }
+}
+
+criterion_group! {
+    name = ascii_lower_bench;
+    config = Criterion::default().sample_size(50);
+    targets = bench_ascii_lower
+}
+criterion_main!(ascii_lower_bench);
diff --git a/tokenizers/src/tokenizer/normalizer.rs b/tokenizers/src/tokenizer/normalizer.rs
index 5bebd5f7b4..e72656c684 100644
--- a/tokenizers/src/tokenizer/normalizer.rs
+++ b/tokenizers/src/tokenizer/normalizer.rs
@@ -544,6 +544,16 @@ impl NormalizedString {
 
     /// Lowercase
     pub fn lowercase(&mut self) -> &mut Self {
+        // ASCII fast path: each `A`..=`Z` becomes a single-byte `a`..=`z`,
+        // so byte length and per-byte alignments are unchanged. We can mutate
+        // bytes in place and skip the Unicode-aware `transform` rebuild.
+        if self.normalized.is_ascii() {
+            // Safety: `ascii_lower` only flips `0x20` on bytes already in
+            // `b'A'..=b'Z'` (all < 0x80), so the result remains valid UTF-8.
+            let bytes = unsafe { self.normalized.as_bytes_mut() };
+            crate::utils::simd::ascii_lower(bytes);
+            return self;
+        }
         let mut new_chars: Vec<(char, isize)> = vec![];
         self.for_each(|c| {
             c.to_lowercase().enumerate().for_each(|(index, c)| {
@@ -2289,6 +2299,45 @@ mod tests {
         assert_eq!(s.get(), "a...");
     }
 
+    #[test]
+    fn lowercase_ascii_fast_path_preserves_alignments() {
+        // After a non-trivial transform (here NFKD on a ligature) the alignments
+        // map several normalized bytes back onto fewer original bytes. The ASCII
+        // fast path must leave that mapping byte-for-byte unchanged.
+        let mut n = NormalizedString::from("ABC\u{FB00}DEF"); // "ABCﬀDEF"; ﬀ (U+FB00) -> "ff" via NFKD
+        n.nfkd();
+        // Sanity: result is now all ASCII so the fast path will trigger.
+        assert!(n.get().is_ascii());
+
+        let bytes_before = n.normalized.clone();
+        let alignments_before = n.alignments.clone();
+        let original_before = n.original.clone();
+        let shift_before = n.original_shift;
+
+        n.lowercase();
+
+        assert_eq!(
+            n.normalized,
+            bytes_before.to_lowercase(),
+            "bytes mismatch fast vs ASCII to_lowercase"
+        );
+        assert_eq!(n.alignments, alignments_before, "alignments mutated");
+        assert_eq!(n.original, original_before, "original mutated");
+        assert_eq!(n.original_shift, shift_before, "original_shift mutated");
+    }
+
+    #[test]
+    fn lowercase_ascii_matches_unicode_path_byte_for_byte() {
+        // Cross-check against char::to_lowercase on every printable ASCII byte:
+        // the fast path must produce exactly the same bytes the slow path would.
+        let input: String = (0x20u8..0x7F).map(|b| b as char).collect();
+        let mut fast = NormalizedString::from(input.as_str());
+        fast.lowercase();
+
+        let expected: String = input.chars().flat_map(|c| c.to_lowercase()).collect();
+        assert_eq!(fast.get(), expected);
+    }
+
     #[test]
     fn test_append_after_clear() {
         let mut n = NormalizedString::from("Hello");
diff --git a/tokenizers/src/utils/mod.rs b/tokenizers/src/utils/mod.rs
index c9450b3222..f263fbd0c0 100644
--- a/tokenizers/src/utils/mod.rs
+++ b/tokenizers/src/utils/mod.rs
@@ -18,6 +18,7 @@ pub mod iter;
 pub mod padding;
 pub mod parallelism;
 pub(crate) mod progress;
+pub mod simd;
 pub mod truncation;
 
 // Re-export ProgressFormat for public API
diff --git a/tokenizers/src/utils/simd.rs b/tokenizers/src/utils/simd.rs
new file mode 100644
index 0000000000..5bea10d77d
--- /dev/null
+++ b/tokenizers/src/utils/simd.rs
@@ -0,0 +1,227 @@
+//! SIMD helpers for ASCII fast paths.
+//!
+//! Each public function dispatches at runtime (via `is_x86_feature_detected!`
+//! on x86_64; aarch64 always has NEON under the stable target ABI) and falls
+//! back to a scalar implementation on other architectures.
+
+/// Lowercase ASCII letters (`A`..=`Z` → `a`..=`z`) in place. Bytes outside that
+/// range are left untouched, so it is safe to call on any byte slice — but for
+/// best speed it should be guarded by an `is_ascii()` check upstream so the
+/// caller can also skip Unicode-aware logic.
+#[inline]
+pub fn ascii_lower(buf: &mut [u8]) {
+    #[cfg(target_arch = "x86_64")]
+    {
+        // Only enable the AVX-512 path on CPUs where the AVX-512 license-mode
+        // downclock is negligible. `avx512fp16` is used as a proxy for such
+        // cores: it is present on Intel Sapphire Rapids and later, and absent
+        // on Skylake-X / Cascade Lake / Cooper Lake / Ice Lake-SP / Rocket
+        // Lake — the generations where 512-bit ops trigger significant
+        // frequency throttling. (AMD Zen 4 does not report `avx512fp16`
+        // either, and therefore takes the AVX2 path below.)
+        if std::is_x86_feature_detected!("avx512f")
+            && std::is_x86_feature_detected!("avx512bw")
+            && std::is_x86_feature_detected!("avx512fp16")
+        {
+            unsafe { return ascii_lower_avx512(buf) };
+        }
+        if std::is_x86_feature_detected!("avx2") {
+            unsafe { return ascii_lower_avx2(buf) };
+        }
+        // SSE2 is part of the x86_64 baseline; always available.
+        unsafe { return ascii_lower_sse2(buf) };
+    }
+    #[cfg(target_arch = "aarch64")]
+    {
+        // NEON is mandatory on the stable aarch64 ABI; no runtime check needed.
+        unsafe { return ascii_lower_neon(buf) };
+    }
+    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
+    {
+        ascii_lower_scalar(buf);
+    }
+}
+
+#[inline(always)]
+fn ascii_lower_scalar(buf: &mut [u8]) {
+    for b in buf {
+        if b.is_ascii_uppercase() {
+            *b |= 0x20;
+        }
+    }
+}
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx512f,avx512bw")]
+unsafe fn ascii_lower_avx512(buf: &mut [u8]) {
+    use std::arch::x86_64::*;
+
+    let a_minus_1 = _mm512_set1_epi8(b'A' as i8 - 1);
+    let z_plus_1 = _mm512_set1_epi8(b'Z' as i8 + 1);
+    let case_bit = _mm512_set1_epi8(0x20);
+
+    let len = buf.len();
+    let mut i = 0;
+    while i + 64 <= len {
+        let p = buf.as_mut_ptr().add(i) as *mut __m512i;
+        let v = _mm512_loadu_si512(p as *const __m512i);
+        // `__mmask64` is `u64` in Rust; the two range checks AND together as a
+        // plain bitwise op. Signed compares are correct here for the same
+        // reason as the SSE2/AVX2 paths: ASCII bytes are < 0x80.
+        let gt_a = _mm512_cmpgt_epi8_mask(v, a_minus_1);
+        let lt_z = _mm512_cmpgt_epi8_mask(z_plus_1, v);
+        let mask = gt_a & lt_z;
+        let mask_vec = _mm512_movm_epi8(mask);
+        let flip = _mm512_and_si512(mask_vec, case_bit);
+        let out = _mm512_xor_si512(v, flip);
+        _mm512_storeu_si512(p, out);
+        i += 64;
+    }
+    ascii_lower_scalar(&mut buf[i..]);
+}
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx2")]
+unsafe fn ascii_lower_avx2(buf: &mut [u8]) {
+    use std::arch::x86_64::*;
+
+    let a_minus_1 = _mm256_set1_epi8(b'A' as i8 - 1); // 0x40
+    let z_plus_1 = _mm256_set1_epi8(b'Z' as i8 + 1); // 0x5B
+    let case_bit = _mm256_set1_epi8(0x20);
+
+    let len = buf.len();
+    let mut i = 0;
+    while i + 32 <= len {
+        let p = buf.as_mut_ptr().add(i) as *mut __m256i;
+        let v = _mm256_loadu_si256(p as *const __m256i);
+        // Signed compares are correct here because all uppercase ASCII bytes
+        // are < 0x80; bytes >= 0x80 appear negative and are excluded from the mask.
+        let gt_a = _mm256_cmpgt_epi8(v, a_minus_1); // v > 0x40
+        let lt_z = _mm256_cmpgt_epi8(z_plus_1, v); // 0x5B > v
+        let mask = _mm256_and_si256(gt_a, lt_z);
+        let flip = _mm256_and_si256(mask, case_bit);
+        let out = _mm256_xor_si256(v, flip);
+        _mm256_storeu_si256(p, out);
+        i += 32;
+    }
+    // Scalar tail (also covers buffers shorter than 32 bytes).
+    ascii_lower_scalar(&mut buf[i..]);
+}
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "sse2")]
+unsafe fn ascii_lower_sse2(buf: &mut [u8]) {
+    use std::arch::x86_64::*;
+
+    let a_minus_1 = _mm_set1_epi8(b'A' as i8 - 1);
+    let z_plus_1 = _mm_set1_epi8(b'Z' as i8 + 1);
+    let case_bit = _mm_set1_epi8(0x20);
+
+    let len = buf.len();
+    let mut i = 0;
+    while i + 16 <= len {
+        let p = buf.as_mut_ptr().add(i) as *mut __m128i;
+        let v = _mm_loadu_si128(p as *const __m128i);
+        let gt_a = _mm_cmpgt_epi8(v, a_minus_1);
+        let lt_z = _mm_cmpgt_epi8(z_plus_1, v);
+        let mask = _mm_and_si128(gt_a, lt_z);
+        let flip = _mm_and_si128(mask, case_bit);
+        let out = _mm_xor_si128(v, flip);
+        _mm_storeu_si128(p, out);
+        i += 16;
+    }
+    ascii_lower_scalar(&mut buf[i..]);
+}
+
+#[cfg(target_arch = "aarch64")]
+unsafe fn ascii_lower_neon(buf: &mut [u8]) {
+    use std::arch::aarch64::*;
+
+    let a_minus_1 = vdupq_n_u8(b'A' - 1);
+    let z_plus_1 = vdupq_n_u8(b'Z' + 1);
+    let case_bit = vdupq_n_u8(0x20);
+
+    let len = buf.len();
+    let mut i = 0;
+    while i + 16 <= len {
+        let p = buf.as_mut_ptr().add(i);
+        let v = vld1q_u8(p);
+        // Unsigned compares on aarch64 — directly available.
+        let gt_a = vcgtq_u8(v, a_minus_1); // v > A-1 → v >= A
+        let lt_z = vcltq_u8(v, z_plus_1); // v < Z+1 → v <= Z
+        let mask = vandq_u8(gt_a, lt_z);
+        let flip = vandq_u8(mask, case_bit);
+        let out = veorq_u8(v, flip);
+        vst1q_u8(p, out);
+        i += 16;
+    }
+    ascii_lower_scalar(&mut buf[i..]);
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn scalar_reference(input: &[u8]) -> Vec<u8> {
+        let mut out = input.to_vec();
+        ascii_lower_scalar(&mut out);
+        out
+    }
+
+    #[test]
+    fn empty() {
+        let mut buf: [u8; 0] = [];
+        ascii_lower(&mut buf);
+    }
+
+    #[test]
+    fn matches_scalar_on_random_ascii() {
+        // Mix of upper, lower, digits, and symbols in a 200-byte buffer, so the
+        // main loop runs for every SIMD width and the trailing bytes exercise
+        // the scalar tail.
+        let mut data: Vec<u8> = (0..200u32)
+            .map(|i| {
+                let c = i as u8;
+                // Cycle through printable ASCII.
+                0x20 + (c % 0x5F)
+            })
+            .collect();
+        let expected = scalar_reference(&data);
+        ascii_lower(&mut data);
+        assert_eq!(data, expected);
+    }
+
+    #[test]
+    fn matches_scalar_at_critical_lengths() {
+        for len in [
+            0, 1, 7, 15, 16, 17, 31, 32, 33, 47, 48, 63, 64, 65, 128, 129,
+        ] {
+            let mut data: Vec<u8> = (0..len as u8).map(|i| b'A' + (i % 26)).collect();
+            let expected = scalar_reference(&data);
+            ascii_lower(&mut data);
+            assert_eq!(data, expected, "len={len}");
+        }
+    }
+
+    #[test]
+    fn leaves_high_bytes_untouched() {
+        // Ensures SIMD masks correctly exclude bytes >= 0x80 (UTF-8 continuation
+        // bytes) — defensive even though the `is_ascii` gate upstream is meant
+        // to filter these out.
+        let seed: Vec<u8> = vec![
+            b'A', b'b', 0xC3, 0xA9, b'Z', 0xE2, 0x82, 0xAC, b'Q', 0x80, 0xFF,
+        ];
+        // Repeat to cross SIMD block boundaries.
+        let mut data = seed.repeat(8);
+        let expected = scalar_reference(&data);
+        ascii_lower(&mut data);
+        assert_eq!(data, expected);
+    }
+
+    #[test]
+    fn idempotent() {
+        let mut data = b"Hello, World! THE QUICK BROWN FOX 1234 JUMPS OVER 0 LAZY DOGS.".to_vec();
+        ascii_lower(&mut data);
+        let once = data.clone();
+        ascii_lower(&mut data);
+        assert_eq!(data, once);
+    }
+}
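For quick reference, a minimal usage sketch of the helper introduced above; the `tokenizers::utils::simd::ascii_lower` path and signature come from this diff, while the buffer contents are purely illustrative:

    use tokenizers::utils::simd::ascii_lower;

    fn main() {
        // Uppercase ASCII letters are lowered in place; every other byte,
        // including any non-ASCII byte, is left untouched.
        let mut buf = b"Hello, WORLD! 123".to_vec();
        ascii_lower(&mut buf);
        assert_eq!(buf, b"hello, world! 123".to_vec());
    }

The benchmark target registered in Cargo.toml above can be run with `cargo bench --bench ascii_lower_benchmark`.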