diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs
index 8bc0f30af..4ed0be3c3 100644
--- a/tokenizers/src/pre_tokenizers/byte_level.rs
+++ b/tokenizers/src/pre_tokenizers/byte_level.rs
@@ -131,6 +131,16 @@ impl PreTokenizer for ByteLevel {
         })?;
         pretokenized.normalize(|normalized| {
             let s = normalized.get();
+            // Fast path: bytes in `0x21..=0x7E` (printable ASCII excluding space)
+            // map to the char with the same code point in `BYTES_CHAR`, i.e.
+            // `BYTES_CHAR[b] == b as char`. So for any token whose bytes all sit
+            // in that range, the per-byte transform produces an output that is
+            // byte-identical to the input and leaves `alignments` unchanged. We
+            // can therefore return without rebuilding anything. The `iter().all`
+            // predicate is a single linear scan over the token's bytes.
+            if !s.is_empty() && s.as_bytes().iter().all(|&b| (0x21..=0x7E).contains(&b)) {
+                return Ok(());
+            }
             let mut transformations: Vec<(char, isize)> = Vec::with_capacity(s.len());
             for (i, cur_char) in s.char_indices() {
                 let size = cur_char.len_utf8();
@@ -568,6 +578,56 @@ mod tests {
         );
     }
 
+    #[test]
+    fn printable_ascii_fast_path_matches_slow_path() {
+        // Tokens whose bytes are all in `0x21..=0x7E` exercise the fast path.
+        // Their normalized form must be byte-identical to the input and their
+        // offsets must still tile the input contiguously.
+        let inputs = [
+            "Hello",
+            "!",
+            "?world!",
+            "abc123XYZ_+-=*/<>",
+            "a", // 1-byte boundary
+            "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ", // > 32 bytes to cross any auto-vectorized chunk boundary
+        ];
+        let bytelevel = ByteLevel::default()
+            .add_prefix_space(false)
+            .use_regex(false);
+        for s in inputs {
+            let mut pretok = PreTokenizedString::from(s);
+            bytelevel.pre_tokenize(&mut pretok).unwrap();
+            let splits: Vec<_> = pretok
+                .get_splits(OffsetReferential::Original, OffsetType::Byte)
+                .into_iter()
+                .map(|(t, o, _)| (t.to_string(), o))
+                .collect();
+            assert_eq!(
+                splits,
+                vec![(s.to_string(), (0, s.len()))],
+                "fast path mangled token: {s:?}"
+            );
+        }
+    }
+
+    #[test]
+    fn fast_path_does_not_swallow_non_printable_bytes() {
+        // Tokens containing a byte outside `0x21..=0x7E` (here a leading space)
+        // must still hit the slow path and get the GPT-2 byte→char mapping
+        // (' ' -> 'Ġ', i.e. U+0120).
+        let bytelevel = ByteLevel::default()
+            .add_prefix_space(false)
+            .use_regex(false);
+        let mut pretok = PreTokenizedString::from(" hi");
+        bytelevel.pre_tokenize(&mut pretok).unwrap();
+        let normalized: Vec<_> = pretok
+            .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
+            .into_iter()
+            .map(|(t, _, _)| t.to_string())
+            .collect();
+        assert_eq!(normalized, vec!["Ġhi".to_string()]);
+    }
+
     #[test]
     fn deserialization() {
         // Before use_regex