4 changes: 4 additions & 0 deletions tokenizers/Cargo.toml
@@ -66,6 +66,10 @@ harness = false
name = "ci_benchmark"
harness = false

[[bench]]
name = "ascii_lower_benchmark"
harness = false

[dependencies]
rand = "0.9"
onig = { version = "6.5.1", default-features = false, optional = true }
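
With this `[[bench]]` registration, the new microbenchmark can be run on its own via `cargo bench --bench ascii_lower_benchmark`; since `harness = false`, Criterion supplies its own `main` instead of the default libtest harness.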
77 changes: 77 additions & 0 deletions tokenizers/benches/ascii_lower_benchmark.rs
@@ -0,0 +1,77 @@
//! Microbenchmark for the ASCII lowercase fast path used by the `Lowercase`
//! normalizer. Compares the SIMD-dispatched `utils::simd::ascii_lower` against
//! the scalar reference and against the previous Unicode-aware path
//! (`char::to_lowercase` per char) on representative buffer sizes.

#[macro_use]
extern crate criterion;

use criterion::{Criterion, Throughput};
use std::hint::black_box;
use tokenizers::utils::simd::ascii_lower;

fn make_buffer(len: usize) -> Vec<u8> {
// Cycle through printable ASCII (mix of upper, lower, digits, punctuation)
// so the upper-case branch fires on roughly 26/95 of bytes.
(0..len).map(|i| 0x20u8 + (i as u8 % 0x5F)).collect()
}

fn scalar_lower(buf: &mut [u8]) {
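// Branchy per-byte reference: setting bit 0x20 maps A..=Z onto a..=z.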
for b in buf {
if b.is_ascii_uppercase() {
*b |= 0x20;
}
}
}

fn unicode_lower(buf: &str) -> String {
// Mirrors what `NormalizedString::lowercase` did before the fast path: a
// per-`char` UTF-8 decode plus the Unicode lowercase mapping, ignoring the
// alignment bookkeeping the real method also does.
let mut out = String::with_capacity(buf.len());
for c in buf.chars() {
for lc in c.to_lowercase() {
out.push(lc);
}
}
out
}

pub fn bench_ascii_lower(c: &mut Criterion) {
for &len in &[64usize, 1024, 16 * 1024, 256 * 1024] {
let mut group = c.benchmark_group(format!("ascii_lower/{len}B"));
group.throughput(Throughput::Bytes(len as u64));

let original = make_buffer(len);
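
// Note: both in-place variants reuse `buf`, so from the second iteration on
// the input is already lowercase. The SIMD kernels do identical work either
// way; the scalar path's branch just becomes perfectly predictable.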

group.bench_function("simd", |b| {
let mut buf = original.clone();
b.iter(|| {
ascii_lower(black_box(&mut buf));
});
});

group.bench_function("scalar", |b| {
let mut buf = original.clone();
b.iter(|| {
scalar_lower(black_box(&mut buf));
});
});

// Stand-in for the pre-SIMD code path.
let s = String::from_utf8(original.clone()).unwrap();
group.bench_function("unicode_chars", |b| {
b.iter(|| {
black_box(unicode_lower(black_box(&s)));
});
});

group.finish();
}
}

criterion_group! {
name = ascii_lower_bench;
config = Criterion::default().sample_size(50);
targets = bench_ascii_lower
}
criterion_main!(ascii_lower_bench);
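
For a quick cross-check of the three variants outside Criterion, a minimal standalone sketch (not part of this PR; it assumes the `tokenizers::utils::simd::ascii_lower` export added in this diff):

```rust
use tokenizers::utils::simd::ascii_lower;

fn main() {
    // Same generator as `make_buffer` above: cycle through printable ASCII.
    let original: Vec<u8> = (0..1024).map(|i| 0x20u8 + (i as u8 % 0x5F)).collect();

    // SIMD-dispatched path, in place.
    let mut simd = original.clone();
    ascii_lower(&mut simd);

    // Scalar reference, same logic as `scalar_lower` above.
    let mut scalar = original.clone();
    for b in &mut scalar {
        if b.is_ascii_uppercase() {
            *b |= 0x20;
        }
    }

    // Unicode path via `char::to_lowercase`, same logic as `unicode_lower` above.
    let s = String::from_utf8(original).unwrap();
    let unicode: String = s.chars().flat_map(|c| c.to_lowercase()).collect();

    assert_eq!(simd, scalar);
    assert_eq!(simd, unicode.into_bytes());
    println!("all three variants agree");
}
```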
49 changes: 49 additions & 0 deletions tokenizers/src/tokenizer/normalizer.rs
@@ -544,6 +544,16 @@ impl NormalizedString {

/// Lowercase
pub fn lowercase(&mut self) -> &mut Self {
// ASCII fast path: each `A`..=`Z` becomes a single-byte `a`..=`z`,
// so byte length and per-byte alignments are unchanged. We can mutate
// bytes in place and skip the Unicode-aware `transform` rebuild.
if self.normalized.is_ascii() {
// Safety: `ascii_lower` only flips bit `0x20` on bytes already in
// `b'A'..=b'Z'` (all < 0x80), so the result remains valid UTF-8.
let bytes = unsafe { self.normalized.as_bytes_mut() };
crate::utils::simd::ascii_lower(bytes);
return self;
}
let mut new_chars: Vec<(char, isize)> = vec![];
self.for_each(|c| {
c.to_lowercase().enumerate().for_each(|(index, c)| {
@@ -2289,6 +2299,45 @@ mod tests {
assert_eq!(s.get(), "a...");
}

#[test]
fn lowercase_ascii_fast_path_preserves_alignments() {
// After a non-trivial transform (here NFKD on a ligature) the alignments
// are no longer the identity: the two normalized bytes of "ff" both map
// back onto the three-byte span of the original ligature. The ASCII fast
// path must leave that mapping byte-for-byte unchanged.
let mut n = NormalizedString::from("ABC\u{FB00}DEF"); // U+FB00 (ﬀ); NFKD expands it to "ff"
n.nfkd();
// Sanity: result is now all ASCII so the fast path will trigger.
assert!(n.get().is_ascii());

let bytes_before = n.normalized.clone();
let alignments_before = n.alignments.clone();
let original_before = n.original.clone();
let shift_before = n.original_shift;

n.lowercase();

assert_eq!(
n.normalized,
bytes_before.to_lowercase(),
"bytes mismatch fast vs ASCII to_lowercase"
);
assert_eq!(n.alignments, alignments_before, "alignments mutated");
assert_eq!(n.original, original_before, "original mutated");
assert_eq!(n.original_shift, shift_before, "original_shift mutated");
}

#[test]
fn lowercase_ascii_matches_unicode_path_byte_for_byte() {
// Cross-check against char::to_lowercase on every printable ASCII byte:
// the fast path must produce exactly the same bytes the slow path would.
let input: String = (0x20u8..0x7F).map(|b| b as char).collect();
let mut fast = NormalizedString::from(input.as_str());
fast.lowercase();

let expected: String = input.chars().flat_map(|c| c.to_lowercase()).collect();
assert_eq!(fast.get(), expected);
}

#[test]
fn test_append_after_clear() {
let mut n = NormalizedString::from("Hello");
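
The `is_ascii()` guard in `lowercase` is what makes the in-place mutation sound: general Unicode lowercasing is not length-preserving, which would invalidate the byte-level alignments. A minimal standalone illustration using only `std` (not part of the PR):

```rust
fn main() {
    // ASCII lowercasing maps each byte 1:1, so byte length never changes.
    assert_eq!("ABC".to_lowercase().len(), "ABC".len());

    // U+0130 (İ) lowercases to 'i' + U+0307 (combining dot above): a single
    // 2-byte char becomes a 3-byte sequence, so in-place mutation would not work.
    let lowered: String = 'İ'.to_lowercase().collect();
    assert_eq!(lowered, "i\u{307}");
    assert_eq!("İ".len(), 2);
    assert_eq!("İ".to_lowercase().len(), 3);
}
```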
1 change: 1 addition & 0 deletions tokenizers/src/utils/mod.rs
@@ -18,6 +18,7 @@ pub mod iter;
pub mod padding;
pub mod parallelism;
pub(crate) mod progress;
pub mod simd;
pub mod truncation;

// Re-export ProgressFormat for public API
227 changes: 227 additions & 0 deletions tokenizers/src/utils/simd.rs
@@ -0,0 +1,227 @@
//! SIMD helpers for ASCII fast paths.
//!
//! Each public function dispatches at runtime (via `is_x86_feature_detected!`
//! on x86_64; aarch64 always has NEON under the stable target ABI) and falls
//! back to a scalar implementation on other architectures.

/// Lowercase ASCII letters (`A`..=`Z` → `a`..=`z`) in place. Bytes outside that
/// range are left untouched, so it is safe to call on any byte slice. For best
/// speed, guard it with an `is_ascii()` check upstream so the caller can also
/// skip Unicode-aware logic.
#[inline]
pub fn ascii_lower(buf: &mut [u8]) {
#[cfg(target_arch = "x86_64")]
{
// Only enable the AVX-512 path on CPUs where the AVX-512 license-mode
// downclock is negligible. `avx512fp16` is used as a proxy: it is
// present on Intel Sapphire Rapids+ and on AMD Zen 4/Zen 5, and is
// absent on Skylake-X / Cascade Lake / Cooper Lake / Ice Lake-SP /
// Rocket Lake — i.e. exactly the generations where 512-bit ops trigger
// significant frequency throttling.
if std::is_x86_feature_detected!("avx512f")
&& std::is_x86_feature_detected!("avx512bw")
&& std::is_x86_feature_detected!("avx512fp16")
{
unsafe { return ascii_lower_avx512(buf) };
}
if std::is_x86_feature_detected!("avx2") {
unsafe { return ascii_lower_avx2(buf) };
}
// SSE2 is part of the x86_64 baseline; always available.
unsafe { return ascii_lower_sse2(buf) };
}
#[cfg(target_arch = "aarch64")]
{
// NEON is mandatory on the stable aarch64 ABI; no runtime check needed.
unsafe { return ascii_lower_neon(buf) };
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
ascii_lower_scalar(buf);
}
}

#[inline(always)]
fn ascii_lower_scalar(buf: &mut [u8]) {
for b in buf {
if b.is_ascii_uppercase() {
*b |= 0x20;
}
}
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f,avx512bw")]
unsafe fn ascii_lower_avx512(buf: &mut [u8]) {
use std::arch::x86_64::*;

let a_minus_1 = _mm512_set1_epi8(b'A' as i8 - 1);
let z_plus_1 = _mm512_set1_epi8(b'Z' as i8 + 1);
let case_bit = _mm512_set1_epi8(0x20);

let len = buf.len();
let mut i = 0;
while i + 64 <= len {
let p = buf.as_mut_ptr().add(i) as *mut __m512i;
let v = _mm512_loadu_si512(p as *const __m512i);
// `__mmask64` is `u64` in Rust; the two range checks AND together as a
// plain bitwise op. Signed compares are correct here for the same
// reason as the SSE2/AVX2 paths: ASCII bytes are < 0x80.
let gt_a = _mm512_cmpgt_epi8_mask(v, a_minus_1);
let lt_z = _mm512_cmpgt_epi8_mask(z_plus_1, v);
let mask = gt_a & lt_z;
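// Expand the k-mask (a u64) back into a byte vector, 0xFF where set, so
// it can gate the case bit below.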
let mask_vec = _mm512_movm_epi8(mask);
let flip = _mm512_and_si512(mask_vec, case_bit);
let out = _mm512_xor_si512(v, flip);
_mm512_storeu_si512(p, out);
i += 64;
}
ascii_lower_scalar(&mut buf[i..]);
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn ascii_lower_avx2(buf: &mut [u8]) {
use std::arch::x86_64::*;

let a_minus_1 = _mm256_set1_epi8(b'A' as i8 - 1); // 0x40
let z_plus_1 = _mm256_set1_epi8(b'Z' as i8 + 1); // 0x5B
let case_bit = _mm256_set1_epi8(0x20);

let len = buf.len();
let mut i = 0;
while i + 32 <= len {
let p = buf.as_mut_ptr().add(i) as *mut __m256i;
let v = _mm256_loadu_si256(p as *const __m256i);
// Signed compares are correct here because all uppercase ASCII bytes
// are < 0x80; bytes >= 0x80 appear negative and are excluded from the mask.
let gt_a = _mm256_cmpgt_epi8(v, a_minus_1); // v > 0x40
let lt_z = _mm256_cmpgt_epi8(z_plus_1, v); // 0x5B > v
let mask = _mm256_and_si256(gt_a, lt_z);
let flip = _mm256_and_si256(mask, case_bit);
let out = _mm256_xor_si256(v, flip);
_mm256_storeu_si256(p, out);
i += 32;
}
// Scalar tail (also covers buffers shorter than 32 bytes).
ascii_lower_scalar(&mut buf[i..]);
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn ascii_lower_sse2(buf: &mut [u8]) {
use std::arch::x86_64::*;

let a_minus_1 = _mm_set1_epi8(b'A' as i8 - 1);
let z_plus_1 = _mm_set1_epi8(b'Z' as i8 + 1);
let case_bit = _mm_set1_epi8(0x20);

let len = buf.len();
let mut i = 0;
while i + 16 <= len {
let p = buf.as_mut_ptr().add(i) as *mut __m128i;
let v = _mm_loadu_si128(p as *const __m128i);
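// Same signed-compare range trick as in the AVX2 path above.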
let gt_a = _mm_cmpgt_epi8(v, a_minus_1);
let lt_z = _mm_cmpgt_epi8(z_plus_1, v);
let mask = _mm_and_si128(gt_a, lt_z);
let flip = _mm_and_si128(mask, case_bit);
let out = _mm_xor_si128(v, flip);
_mm_storeu_si128(p, out);
i += 16;
}
ascii_lower_scalar(&mut buf[i..]);
}

#[cfg(target_arch = "aarch64")]
unsafe fn ascii_lower_neon(buf: &mut [u8]) {
use std::arch::aarch64::*;

let a_minus_1 = vdupq_n_u8(b'A' - 1);
let z_plus_1 = vdupq_n_u8(b'Z' + 1);
let case_bit = vdupq_n_u8(0x20);

let len = buf.len();
let mut i = 0;
while i + 16 <= len {
let p = buf.as_mut_ptr().add(i);
let v = vld1q_u8(p);
// aarch64 has unsigned byte compares, so no signed-compare trick is needed.
let gt_a = vcgtq_u8(v, a_minus_1); // v > A-1 → v >= A
let lt_z = vcltq_u8(v, z_plus_1); // v < Z+1 → v <= Z
let mask = vandq_u8(gt_a, lt_z);
let flip = vandq_u8(mask, case_bit);
let out = veorq_u8(v, flip);
vst1q_u8(p, out);
i += 16;
}
ascii_lower_scalar(&mut buf[i..]);
}

#[cfg(test)]
mod tests {
use super::*;

fn scalar_reference(input: &[u8]) -> Vec<u8> {
let mut out = input.to_vec();
ascii_lower_scalar(&mut out);
out
}

#[test]
fn empty() {
let mut buf: [u8; 0] = [];
ascii_lower(&mut buf);
}

#[test]
fn matches_scalar_on_mixed_ascii() {
// Mix of upper, lower, digits, and symbols: 200 bytes covers several full
// blocks plus a scalar tail for the 16-, 32-, and 64-byte SIMD widths.
let mut data: Vec<u8> = (0..200u32)
.map(|i| {
let c = i as u8;
// Cycle through printable ASCII.
0x20 + (c % 0x5F)
})
.collect();
let expected = scalar_reference(&data);
ascii_lower(&mut data);
assert_eq!(data, expected);
}

#[test]
fn matches_scalar_at_critical_lengths() {
for len in [
0, 1, 7, 15, 16, 17, 31, 32, 33, 47, 48, 63, 64, 65, 128, 129,
] {
let mut data: Vec<u8> = (0..len as u8).map(|i| b'A' + (i % 26)).collect();
let expected = scalar_reference(&data);
ascii_lower(&mut data);
assert_eq!(data, expected, "len={len}");
}
}

#[test]
fn leaves_high_bytes_untouched() {
// Ensures SIMD masks correctly exclude bytes >= 0x80 (UTF-8 continuation
// bytes) — defensive even though the gate is meant to filter these out.
let seed: Vec<u8> = vec![
b'A', b'b', 0xC3, 0xA9, b'Z', 0xE2, 0x82, 0xAC, b'Q', 0x80, 0xFF,
];
// Repeat to cross SIMD block boundaries.
let mut data = seed.repeat(8);
let expected = scalar_reference(&data);
ascii_lower(&mut data);
assert_eq!(data, expected);
}

#[test]
fn idempotent() {
let mut data = b"Hello, World! THE QUICK BROWN FOX 1234 JUMPS OVER 0 LAZY DOGS.".to_vec();
ascii_lower(&mut data);
let once = data.clone();
ascii_lower(&mut data);
assert_eq!(data, once);
}
}
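
All four vector kernels share one branchless data flow: compare the byte into an all-ones mask when it lies in `b'A'..=b'Z'`, AND the mask with the case bit `0x20`, and XOR the result into the byte. A scalar transcription of that flow (illustrative only; the dispatched kernels above are what the crate runs):

```rust
fn ascii_lower_branchless(buf: &mut [u8]) {
    for b in buf.iter_mut() {
        // One unsigned compare replaces the two-sided range check: bytes below
        // b'A' wrap around to large values, so only b'A'..=b'Z' lands below 26.
        let is_upper = (b.wrapping_sub(b'A') < 26) as u8;
        // 0xFF when uppercase, 0x00 otherwise, mirroring the SIMD compare result.
        let mask = 0u8.wrapping_sub(is_upper);
        // Flip the case bit only where the mask is set.
        *b ^= mask & 0x20;
    }
}
```

Bytes at or above 0x80 stay at least 26 after the wrapping subtraction (e.g. 0x80 - 0x41 = 0x3F), so high bytes are left untouched, matching the signed-compare reasoning in the x86 paths.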