Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions tokenizers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ harness = false
name = "truncation_benchmark"
harness = false

# Benchmark comparing the regex-based and manual whitespace pre-tokenizers.
# `harness = false` disables libtest so criterion can provide its own `main`.
[[bench]]
name = "whitespace_pretok_benchmark"
harness = false

[dependencies]
rand = "0.9"
onig = { version = "6.5.1", default-features = false, optional = true }
Expand Down
96 changes: 96 additions & 0 deletions tokenizers/benches/whitespace_pretok_benchmark.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#[macro_use]
extern crate criterion;

use criterion::{BenchmarkId, Criterion, Throughput};
use std::hint::black_box;
use tokenizers::pre_tokenizers::whitespace::{ManualWhitespaceSplit, Whitespace};
use tokenizers::{PreTokenizedString, PreTokenizer};

/// Benchmarks the regex-backed `Whitespace` pre-tokenizer against the
/// hand-rolled `ManualWhitespaceSplit`: once over the full corpus, once
/// line by line (per-call overhead), and once across several input sizes.
///
/// Requires the `data/big.txt` fixture used by the other benchmarks.
fn bench_pretokenizer(c: &mut Criterion) {
    // `expect` instead of a bare `unwrap`: a missing fixture should tell the
    // user what to do, not panic with an opaque I/O error.
    let data = std::fs::read_to_string("data/big.txt")
        .expect("benchmark fixture data/big.txt not found; fetch the bench corpus first");
    let lines: Vec<&str> = data.lines().collect();

    let mut group = c.benchmark_group("whitespace-pretok");
    // Both the full-corpus and line-by-line benches process `data.len()` bytes
    // per iteration, so a single throughput setting covers the whole group.
    group.throughput(Throughput::Bytes(data.len() as u64));

    // Full corpus as a single string
    group.bench_function("regex/full-corpus", |b| {
        let pretok = Whitespace {};
        b.iter(|| {
            let mut pre = PreTokenizedString::from(black_box(data.as_str()));
            pretok.pre_tokenize(&mut pre).unwrap();
            pre
        })
    });

    group.bench_function("manual/full-corpus", |b| {
        let pretok = ManualWhitespaceSplit {};
        b.iter(|| {
            let mut pre = PreTokenizedString::from(black_box(data.as_str()));
            pretok.pre_tokenize(&mut pre).unwrap();
            pre
        })
    });

    // Line-by-line (many short strings — tests per-call overhead)
    group.bench_function("regex/line-by-line", |b| {
        let pretok = Whitespace {};
        b.iter(|| {
            for line in &lines {
                let mut pre = PreTokenizedString::from(black_box(*line));
                pretok.pre_tokenize(&mut pre).unwrap();
                black_box(&pre);
            }
        })
    });

    group.bench_function("manual/line-by-line", |b| {
        let pretok = ManualWhitespaceSplit {};
        b.iter(|| {
            for line in &lines {
                let mut pre = PreTokenizedString::from(black_box(*line));
                pretok.pre_tokenize(&mut pre).unwrap();
                black_box(&pre);
            }
        })
    });

    group.finish();

    // --- Scaling with input size ---

    let mut group = c.benchmark_group("whitespace-pretok-scaling");

    for size in [100, 1_000, 10_000, 100_000] {
        // NOTE: `size` counts chars, but throughput below is set from the
        // actual byte length, so multi-byte text is still reported correctly.
        let input: String = data.chars().take(size).collect();
        group.throughput(Throughput::Bytes(input.len() as u64));

        group.bench_with_input(BenchmarkId::new("regex", size), &input, |b, input| {
            let pretok = Whitespace {};
            b.iter(|| {
                let mut pre = PreTokenizedString::from(black_box(input.as_str()));
                pretok.pre_tokenize(&mut pre).unwrap();
                pre
            })
        });

        group.bench_with_input(BenchmarkId::new("manual", size), &input, |b, input| {
            let pretok = ManualWhitespaceSplit {};
            b.iter(|| {
                let mut pre = PreTokenizedString::from(black_box(input.as_str()));
                pretok.pre_tokenize(&mut pre).unwrap();
                pre
            })
        });
    }

    group.finish();
}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bench a full encode_batch please


// Register the benchmark group with criterion. A sample size of 50 keeps the
// run time manageable on the large corpus while still yielding stable stats.
criterion_group! {
    name = whitespace_pretok;
    config = Criterion::default().sample_size(50);
    targets = bench_pretokenizer
}

criterion_main!(whitespace_pretok);
196 changes: 196 additions & 0 deletions tokenizers/src/pre_tokenizers/whitespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::sync::LazyLock;

use regex::Regex;

use crate::pattern::Pattern;
use crate::tokenizer::{
pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior,
};
Expand Down Expand Up @@ -40,6 +41,85 @@ impl PreTokenizer for WhitespaceSplit {
}
}

/// Whitespace pre-tokenizer implemented with a hand-written character scanner
/// (`WhiteSpacePattern`) instead of the regex used by `Whitespace`.
/// Intended to produce identical splits and offsets to the regex version.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct ManualWhitespaceSplit;

impl PreTokenizer for ManualWhitespaceSplit {
    fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
        // Split each sub-string on the manual pattern; `Removed` drops the
        // delimiter (whitespace) spans entirely.
        pretokenized.split(|_, normalized| {
            normalized.split(WhiteSpacePattern, SplitDelimiterBehavior::Removed)
        })
    }
}

/// Coarse character class used by the manual scanner. Mirrors the regions the
/// `\w+|[^\w\s]+` regex distinguishes (plus the whitespace in between).
#[derive(Clone, Copy, Eq, PartialEq)]
enum CharType {
    // `\s` — these runs become delimiters and are removed by the splitter.
    Whitespace,
    // `\w` — see `is_word_char` for the exact character set.
    Word,
    // Anything else (`[^\w\s]`).
    Symbol,
}

/// Allocation-light replacement for the inverted `\w+|[^\w\s]+` regex used by
/// the `Whitespace` pre-tokenizer.
struct WhiteSpacePattern;

impl Pattern for WhiteSpacePattern {
    /// Partitions `inside` into maximal runs of a single [`CharType`].
    ///
    /// The boolean flag marks *delimiters*: whitespace runs are reported as
    /// matches (`true`) so that `SplitDelimiterBehavior::Removed` drops them,
    /// while word/symbol runs are non-matches (`false`) and survive as tokens.
    /// (The original inline comment stated the opposite polarity; the code —
    /// `run_type == CharType::Whitespace` — is what the splitter relies on.)
    fn find_matches(&self, inside: &str) -> Result<Vec<(crate::Offsets, bool)>> {
        // Convention shared with the regex-backed patterns: an empty input
        // yields one empty non-match span.
        if inside.is_empty() {
            return Ok(vec![((0, 0), false)]);
        }

        let mut spans = Vec::new();
        let mut iter = inside.char_indices();
        // Non-empty string, so the first char always exists.
        let (_, first) = iter.next().expect("non-empty string has a first char");
        let mut run_start = 0;
        let mut run_type = classify(first);

        for (idx, ch) in iter {
            let ct = classify(ch);
            if ct != run_type {
                // Close the previous run; whitespace runs are delimiters.
                spans.push(((run_start, idx), run_type == CharType::Whitespace));
                run_start = idx;
                run_type = ct;
            }
        }

        // Close the trailing run, which always extends to the end of `inside`.
        spans.push(((run_start, inside.len()), run_type == CharType::Whitespace));

        Ok(spans)
    }
}

/// Maps a character to the coarse class used by the manual scanner.
/// Order matters: whitespace wins over word, word wins over symbol.
fn classify(ch: char) -> CharType {
    match ch {
        c if c.is_whitespace() => CharType::Whitespace,
        c if is_word_char(c) => CharType::Word,
        _ => CharType::Symbol,
    }
}

/// Matches the same characters as the Unicode-aware `\w` regex class
/// (UTS#18): Alphabetic + Nd (decimal digit) + Pc (connector punctuation) +
/// M (marks) + Join_Control.
///
/// NOTE(review): `char::is_alphabetic` tests the `Alphabetic` property, which
/// itself includes `Nl` (letter numbers, e.g. Roman numerals) — so `Nl` IS
/// matched here, same as the regex crate's `\w`. Only `No` (other numbers,
/// which `is_alphanumeric` would include) is excluded.
fn is_word_char(ch: char) -> bool {
    use unicode_categories::UnicodeCategories;

    ch.is_alphabetic()
        || ch.is_number_decimal_digit()
        || ch.is_punctuation_connector()
        || ch.is_mark()
        || ch == '\u{200c}' // Zero-Width Non-Joiner (Join_Control)
        || ch == '\u{200d}' // Zero-Width Joiner (Join_Control)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -102,4 +182,120 @@ mod tests {
);
}
}

#[test]
fn assert_equivalent() {
    // The manual scanner must produce exactly the same splits and byte
    // offsets as the regex-backed `Whitespace` pre-tokenizer.
    let regex_pretok = Whitespace {};
    let manual_pretok = ManualWhitespaceSplit {};

    // Collect (substring, byte-offsets) pairs for comparison.
    let splits_of = |pre: &PreTokenizedString| {
        pre.get_splits(OffsetReferential::Original, OffsetType::Byte)
            .into_iter()
            .map(|(s, o, _)| (s, o))
            .collect::<Vec<_>>()
    };

    for test_case in [
        "Hello world!",
        "How are you doing?",
        "This is a test with numbers 123 and symbols @#$%",
        "Multiple spaces",
        "Tabs\tand\nnewlines",
        "Unicode: café résumé naïve",
        "Mixed: Hello123!@# world",
        "Edge cases: a.b,c;d:e",
        "Empty string:",
        "Only spaces: ",
        "Only symbols: !@#$%",
        "Only words: hello world",
        "Numbers: 123 456 789",
        "Underscores: hello_world test_case",
        "Special chars: αβγ δέζ ηθι",
    ] {
        let mut by_regex = PreTokenizedString::from(test_case);
        let mut by_manual = PreTokenizedString::from(test_case);

        regex_pretok.pre_tokenize(&mut by_regex).unwrap();
        manual_pretok.pre_tokenize(&mut by_manual).unwrap();

        assert_eq!(
            splits_of(&by_regex),
            splits_of(&by_manual),
            "Mismatch for test case: '{}'",
            test_case
        );
    }
}

#[test]
fn manual_edge_cases() {
    let pretok = ManualWhitespaceSplit {};

    // Test various edge cases. Each entry is (input, expected
    // (token, byte-offset) pairs). Whitespace-only inputs yield no tokens;
    // word/symbol boundaries split without consuming any bytes.
    let edge_cases = vec![
        ("", vec![]),
        (" ", vec![]),
        (" ", vec![]),
        ("a", vec![("a", (0, 1))]),
        ("!", vec![("!", (0, 1))]),
        ("a!", vec![("a", (0, 1)), ("!", (1, 2))]),
        ("!a", vec![("!", (0, 1)), ("a", (1, 2))]),
        ("a b", vec![("a", (0, 1)), ("b", (2, 3))]),
        ("a b", vec![("a", (0, 1)), ("b", (3, 4))]),
        ("a\tb", vec![("a", (0, 1)), ("b", (2, 3))]),
        ("a\nb", vec![("a", (0, 1)), ("b", (2, 3))]),
        // "\r\n" is two whitespace chars, so "b" starts at byte 3.
        ("a\r\nb", vec![("a", (0, 1)), ("b", (3, 4))]),
    ];

    for (input, expected) in edge_cases {
        let mut pretokenized = PreTokenizedString::from(input);
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        let result = pretokenized
            .get_splits(OffsetReferential::Original, OffsetType::Byte)
            .into_iter()
            .map(|(s, o, _)| (s, o))
            .collect::<Vec<_>>();
        assert_eq!(result, expected, "Failed for input: '{}'", input);
    }
}

#[test]
fn assert_equivalent_xnli() {
    // Cross-check both implementations line by line over the multilingual
    // XNLI corpus fixture.
    let data = std::fs::read_to_string("data/xnli.txt").unwrap();
    let regex_pretok = Whitespace {};
    let manual_pretok = ManualWhitespaceSplit {};

    // Collect (substring, byte-offsets) pairs for comparison.
    let splits_of = |pre: &PreTokenizedString| {
        pre.get_splits(OffsetReferential::Original, OffsetType::Byte)
            .into_iter()
            .map(|(s, o, _)| (s, o))
            .collect::<Vec<_>>()
    };

    for (i, line) in data.lines().enumerate() {
        let mut by_regex = PreTokenizedString::from(line);
        let mut by_manual = PreTokenizedString::from(line);

        regex_pretok.pre_tokenize(&mut by_regex).unwrap();
        manual_pretok.pre_tokenize(&mut by_manual).unwrap();

        assert_eq!(
            splits_of(&by_regex),
            splits_of(&by_manual),
            "Mismatch on line {}: '{}'",
            i,
            &line.chars().take(80).collect::<String>(),
        );
    }
}
}
Loading