diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index 0e937f3cc5..aafedb8251 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -93,6 +93,7 @@ getrandom = { version = "0.3" } esaxx-rs = { version = "0.1.10", default-features = false, features = [] } monostate = "0.1.12" ahash = { version = "0.8.11", features = ["serde"] } +rustc-hash = "2" dary_heap = { version = "0.3.6", features = ["serde"] } compact_str = { version = "0.9", features = ["serde"] } diff --git a/tokenizers/src/models/bpe/mod.rs b/tokenizers/src/models/bpe/mod.rs index f0d40b2df6..603c29c815 100644 --- a/tokenizers/src/models/bpe/mod.rs +++ b/tokenizers/src/models/bpe/mod.rs @@ -1,4 +1,5 @@ //! [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model. +use rustc_hash::FxHashMap; use std::{iter, mem}; mod model; @@ -8,6 +9,80 @@ mod word; type Pair = (u32, u32); +/// Packs a `(u32, u32)` pair into a single `u64` for faster hashing. +#[inline] +fn pack_pair(pair: &Pair) -> u64 { + (pair.0 as u64) << 32 | pair.1 as u64 +} + +/// Unpacks a `u64` back into a `(u32, u32)` pair. +#[inline] +fn unpack_pair(packed: u64) -> Pair { + ((packed >> 32) as u32, packed as u32) +} + +/// A merge-lookup map that packs `(u32, u32)` pair keys into single `u64` values +/// for faster hashing (single FxHash multiply instead of hashing two fields). +/// +/// Values are `(rank, new_id)` tuples. +#[derive(Clone, Debug)] +pub(crate) struct MergeMap { + inner: FxHashMap, +} + +impl MergeMap { + #[allow(dead_code)] + pub fn new() -> Self { + MergeMap { + inner: FxHashMap::default(), + } + } + + pub fn with_capacity(cap: usize) -> Self { + MergeMap { + inner: FxHashMap::with_capacity_and_hasher(cap, Default::default()), + } + } + + #[inline] + /// Get `(rank, new_id)` for a given `Pair` in the map. + pub fn get(&self, pair: &Pair) -> Option<&(u32, u32)> { + self.inner.get(&pack_pair(pair)) + } + + /// Insert `(rank, new_id)` for a given `Pair` in the map. + pub fn insert(&mut self, pair: Pair, value: (u32, u32)) -> Option<(u32, u32)> { + self.inner.insert(pack_pair(&pair), value) + } + + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Iterate over `(Pair, &(rank, new_id))`. + pub fn iter(&self) -> impl Iterator { + self.inner.iter().map(|(k, v)| (unpack_pair(*k), v)) + } +} + +impl PartialEq for MergeMap { + fn eq(&self, other: &Self) -> bool { + self.inner == other.inner + } +} + +impl std::iter::FromIterator<(Pair, (u32, u32))> for MergeMap { + fn from_iter>(iter: I) -> Self { + let iter = iter.into_iter(); + let (lo, _) = iter.size_hint(); + let mut map = MergeMap::with_capacity(lo); + for (pair, val) in iter { + map.insert(pair, val); + } + map + } +} + /// Errors that can be encountered while using or constructing a `BPE` model. #[derive(thiserror::Error, Debug)] pub enum Error { diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index c0e4f7d84d..a1b072e8be 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -1,10 +1,11 @@ -use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word}; +use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, MergeMap, Pair, Word}; use crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY, MAX_LENGTH}; use crate::utils::iter::ResultShunt; use ahash::AHashMap; use serde_json::Value; use std::borrow::Cow; +use std::cell::RefCell; use std::collections::HashMap; use std::str::from_utf8_unchecked; @@ -15,9 +16,12 @@ use std::{ path::{Path, PathBuf}, }; +thread_local! { + static TL_WORD: RefCell = RefCell::new(Word::with_capacity(64)); +} + pub type Vocab = AHashMap; type VocabR = AHashMap; -pub type MergeMap = AHashMap; pub type Merges = Vec<(String, String)>; struct Config { @@ -495,17 +499,23 @@ impl BPE { )]); } } - if let Some(ref hit) = self.cache.as_ref().and_then(|c| c.get(sequence)) { - return Ok(self.word_to_tokens(hit).collect()); - } - let word = self.merge_word(sequence)?; - let ret = self.word_to_tokens(&word).collect(); - if let Some(ref cache) = self.cache { - if sequence.len() < MAX_LENGTH { - cache.set(sequence.to_owned(), word); + TL_WORD.with(|w| { + let mut word = w.borrow_mut(); + word.clear(); + if let Some(ref cache) = self.cache { + if cache.get_into(sequence, &mut word) { + return Ok(self.word_to_tokens(&word).collect()); + } } - } - Ok(ret) + let word = self.merge_word(sequence)?; + let ret = self.word_to_tokens(&word).collect(); + if let Some(ref cache) = self.cache { + if sequence.len() < MAX_LENGTH { + cache.set(sequence.to_owned(), word); + } + } + Ok(ret) + }) } } @@ -566,12 +576,12 @@ impl Model for BPE { .iter() .collect(); let mut merges_file = File::create(&merges_path)?; - let mut merges: Vec<(&Pair, &u32)> = self + let mut merges: Vec<(Pair, u32)> = self .merges .iter() - .map(|(pair, (rank, _))| (pair, rank)) + .map(|(pair, (rank, _))| (pair, *rank)) .collect(); - merges.sort_unstable_by_key(|k| *k.1); + merges.sort_unstable_by_key(|k| k.1); merges_file.write_all(b"#version: 0.2\n")?; merges_file.write_all( &merges diff --git a/tokenizers/src/models/bpe/serialization.rs b/tokenizers/src/models/bpe/serialization.rs index 98cf549445..c28f2b184f 100644 --- a/tokenizers/src/models/bpe/serialization.rs +++ b/tokenizers/src/models/bpe/serialization.rs @@ -24,12 +24,12 @@ impl Serialize for BPE { model.serialize_field("ignore_merges", &self.ignore_merges)?; // Then the large ones - let mut merges: Vec<(&Pair, &u32)> = self + let mut merges: Vec<(Pair, u32)> = self .merges .iter() - .map(|(pair, (rank, _))| (pair, rank)) + .map(|(pair, (rank, _))| (pair, *rank)) .collect(); - merges.sort_unstable_by_key(|k| *k.1); + merges.sort_unstable_by_key(|k| k.1); let merges = merges .into_iter() .map(|(pair, _)| (self.vocab_r[&pair.0].clone(), self.vocab_r[&pair.1].clone())) diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs index df68c655e9..6958362407 100644 --- a/tokenizers/src/models/bpe/trainer.rs +++ b/tokenizers/src/models/bpe/trainer.rs @@ -677,7 +677,8 @@ impl Trainer for BpeTrainer { #[cfg(test)] mod tests { - use super::{BpeTrainer, Pair, BPE}; + use super::{BpeTrainer, BPE}; + use crate::models::bpe::MergeMap; use ahash::AHashMap; use compact_str::CompactString; @@ -744,7 +745,7 @@ mod tests { // where 'rank' determines the order in which this merge will be applied during // tokenization, and 'id' is the vocab id of the symbol resulting from merging // the pair of symbols in the corresponding key. - let expected_merges: AHashMap = [ + let expected_merges: MergeMap = [ ((17, 11), (0, 22)), // 'r' + 'e' -> 're' ((8, 22), (1, 23)), // 'a' + 're' -> 'are' ((13, 18), (2, 24)), // 'i' + 's' -> 'is' diff --git a/tokenizers/src/models/bpe/word.rs b/tokenizers/src/models/bpe/word.rs index 7bf2dee566..640ec1da60 100644 --- a/tokenizers/src/models/bpe/word.rs +++ b/tokenizers/src/models/bpe/word.rs @@ -1,9 +1,15 @@ -use super::Pair; -use ahash::AHashMap; +use super::{MergeMap, Pair}; +use crate::utils::cache::ExtendFromRef; use dary_heap::QuaternaryHeap; use rand::{rng, Rng}; +use std::cell::RefCell; use std::cmp::Ordering; +thread_local! { + static TL_MERGE_HEAP: RefCell> = RefCell::new(QuaternaryHeap::new()); + static TL_MERGE_SKIP: RefCell> = const { RefCell::new(Vec::new()) }; +} + #[derive(Debug, Eq)] struct Merge { pos: usize, @@ -57,6 +63,13 @@ impl Symbol { pub(super) struct Word { symbols: Vec, } + +impl ExtendFromRef for Word { + fn extend_from_ref(&mut self, other: &Self) { + self.symbols.extend_from_slice(&other.symbols); + } +} + impl std::fmt::Debug for Word { fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { fmt.debug_struct("Word") @@ -85,6 +98,10 @@ impl Word { } } + pub(super) fn clear(&mut self) { + self.symbols.clear(); + } + pub(super) fn add(&mut self, c: u32, byte_len: usize) { let (prev, next) = { let len = self.symbols.len() as isize; @@ -159,91 +176,97 @@ impl Word { changes } - pub(super) fn merge_all(&mut self, merges: &AHashMap, dropout: Option) { - let mut queue = QuaternaryHeap::with_capacity(self.symbols.len()); - let mut skip = Vec::with_capacity(queue.len()); - - queue.extend( - self.symbols - .windows(2) - .enumerate() - .filter_map(|(index, window)| { - let pair = (window[0].c, window[1].c); - merges.get(&pair).map(|m| Merge { - pos: index, - rank: m.0, - new_id: m.1, - }) - }), - ); - - while let Some(top) = queue.pop() { - if dropout.map(|d| rng().random::() < d).unwrap_or(false) { - skip.push(top); - } else { - // Re-insert the skipped elements - queue.extend(skip.drain(..)); - - if self.symbols[top.pos].len == 0 { - continue; - } - // Do nothing if we are the last symbol - if self.symbols[top.pos].next == -1 { - continue; - } - - let next_pos = self.symbols[top.pos].next as usize; - let right = self.symbols[next_pos]; - - // Make sure we are not processing an expired queue entry - let target_new_pair = (self.symbols[top.pos].c, right.c); - if merges - .get(&target_new_pair) - .is_none_or(|(_, new_id)| *new_id != top.new_id) - { - continue; - } - - // Otherwise, let's merge - self.symbols[top.pos].merge_with(&right, top.new_id); - // Tag the right part as removed - self.symbols[next_pos].len = 0; - - // Update `prev` on the new `next` to the current pos - if right.next > -1 && (right.next as usize) < self.symbols.len() { - self.symbols[right.next as usize].prev = top.pos as isize; - } - - // Insert the new pair formed with the previous symbol - let current = &self.symbols[top.pos]; - if current.prev >= 0 { - let prev = current.prev as usize; - let prev_symbol = self.symbols[prev]; - let new_pair = (prev_symbol.c, current.c); - if let Some((rank, new_id)) = merges.get(&new_pair) { - queue.push(Merge { - pos: current.prev as usize, - rank: *rank, - new_id: *new_id, - }); + pub(super) fn merge_all(&mut self, merges: &MergeMap, dropout: Option) { + TL_MERGE_HEAP.with(|heap_cell| { + TL_MERGE_SKIP.with(|skip_cell| { + let mut queue = heap_cell.borrow_mut(); + let mut skip = skip_cell.borrow_mut(); + queue.clear(); + skip.clear(); + + queue.extend( + self.symbols + .windows(2) + .enumerate() + .filter_map(|(index, window)| { + let pair = (window[0].c, window[1].c); + merges.get(&pair).map(|m| Merge { + pos: index, + rank: m.0, + new_id: m.1, + }) + }), + ); + + while let Some(top) = queue.pop() { + if dropout.map(|d| rng().random::() < d).unwrap_or(false) { + skip.push(top); + } else { + // Re-insert the skipped elements + queue.extend(skip.drain(..)); + + if self.symbols[top.pos].len == 0 { + continue; + } + // Do nothing if we are the last symbol + if self.symbols[top.pos].next == -1 { + continue; + } + + let next_pos = self.symbols[top.pos].next as usize; + let right = self.symbols[next_pos]; + + // Make sure we are not processing an expired queue entry + let target_new_pair = (self.symbols[top.pos].c, right.c); + if merges + .get(&target_new_pair) + .is_none_or(|(_, new_id)| *new_id != top.new_id) + { + continue; + } + + // Otherwise, let's merge + self.symbols[top.pos].merge_with(&right, top.new_id); + // Tag the right part as removed + self.symbols[next_pos].len = 0; + + // Update `prev` on the new `next` to the current pos + if right.next > -1 && (right.next as usize) < self.symbols.len() { + self.symbols[right.next as usize].prev = top.pos as isize; + } + + // Insert the new pair formed with the previous symbol + let current = &self.symbols[top.pos]; + if current.prev >= 0 { + let prev = current.prev as usize; + let prev_symbol = self.symbols[prev]; + let new_pair = (prev_symbol.c, current.c); + if let Some((rank, new_id)) = merges.get(&new_pair) { + queue.push(Merge { + pos: current.prev as usize, + rank: *rank, + new_id: *new_id, + }); + } + } + + // Insert the new pair formed with the next symbol + let next = current.next as usize; + if next < self.symbols.len() { + let next_symbol = self.symbols[next]; + let new_pair = (current.c, next_symbol.c); + if let Some((rank, new_id)) = merges.get(&new_pair) { + queue.push(Merge { + pos: top.pos, + rank: *rank, + new_id: *new_id, + }); + } + } } } - - // Insert the new pair formed with the next symbol - let next = current.next as usize; - if next < self.symbols.len() { - let next_symbol = self.symbols[next]; - let new_pair = (current.c, next_symbol.c); - if let Some((rank, new_id)) = merges.get(&new_pair) { - queue.push(Merge { - pos: top.pos, - rank: *rank, - new_id: *new_id, - }); - } - } - } - } + }); + }); // Filter out the removed symbols self.symbols.retain(|s| s.len != 0); diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index 3a9a6bddbd..89e5da55aa 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -233,19 +233,19 @@ impl Unigram { return Ok(vec![]); } if self.alpha.is_none() || self.alpha == Some(0.0) { - if let Some(result) = self.cache.get(sentence) { - Ok(result.to_vec()) + let mut result = Vec::new(); + if self.cache.get_into(sentence, &mut result) { + return Ok(result); + } + let result = if self.is_optimized { + self.encode_optimized(sentence)? } else { - let result = if self.is_optimized { - self.encode_optimized(sentence)? - } else { - self.encode_unoptimized(sentence)? - }; - if sentence.len() < MAX_LENGTH { - self.cache.set(sentence.to_owned(), result.clone()); - } - Ok(result) + self.encode_unoptimized(sentence)? + }; + if sentence.len() < MAX_LENGTH { + self.cache.set(sentence.to_owned(), result.clone()); } + Ok(result) } else { let result = self.encode_unoptimized(sentence)?; Ok(result) diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs index 8bc0f30af0..4295ecef7e 100644 --- a/tokenizers/src/pre_tokenizers/byte_level.rs +++ b/tokenizers/src/pre_tokenizers/byte_level.rs @@ -48,6 +48,17 @@ static BYTES_CHAR: LazyLock> = LazyLock::new(bytes_char); static CHAR_BYTES: LazyLock> = LazyLock::new(|| bytes_char().into_iter().map(|(c, b)| (b, c)).collect()); +/// Flat lookup table: byte value → unicode char. Eliminates HashMap lookup +/// in the byte-level encoding hot path. +static BYTE_TO_CHAR: LazyLock<[char; 256]> = LazyLock::new(|| { + let map = bytes_char(); + let mut table = ['\0'; 256]; + for (b, c) in &map { + table[*b as usize] = *c; + } + table +}); + #[derive(Copy, Clone, Debug, PartialEq, Eq)] /// Provides all the necessary steps to handle the BPE tokenization at the byte-level. Takes care /// of all the required processing steps to transform a UTF-8 string as needed before and after the @@ -131,15 +142,15 @@ impl PreTokenizer for ByteLevel { })?; pretokenized.normalize(|normalized| { let s = normalized.get(); + let table = &*BYTE_TO_CHAR; let mut transformations: Vec<(char, isize)> = Vec::with_capacity(s.len()); for (i, cur_char) in s.char_indices() { let size = cur_char.len_utf8(); - transformations.extend( - s.as_bytes()[i..i + size] - .iter() - .enumerate() - .map(|(i, b)| (BYTES_CHAR[b], isize::from(i > 0))), - ); + let bytes = &s.as_bytes()[i..i + size]; + transformations.push((table[bytes[0] as usize], 0)); + for &b in &bytes[1..] { + transformations.push((table[b as usize], 1)); + } } normalized.transform(transformations, 0); Ok(()) diff --git a/tokenizers/src/utils/cache.rs b/tokenizers/src/utils/cache.rs index 15c6b65f18..aa8a7067bc 100644 --- a/tokenizers/src/utils/cache.rs +++ b/tokenizers/src/utils/cache.rs @@ -1,6 +1,6 @@ -use ahash::AHashMap; +use rustc_hash::FxHashMap; use std::borrow::Borrow; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::RwLock; /// The default capacity for a `BPE`'s internal cache. @@ -9,25 +9,132 @@ pub static DEFAULT_CACHE_CAPACITY: usize = 10_000; /// Strings that are too long have minimal chances to cache hit anyway pub static MAX_LENGTH: usize = 256; -/// Provides a simple multithread cache to speed up BPE tokenization that will try to read values -/// concurrently but won't block if another thread is writing. -/// The goal is clearly not the accuracy of the content, both get and set -/// are not guaranteed to actually get or set. -#[derive(Debug)] +/// Number of shards in the sharded cache. +const SHARED_CACHE_SHARDS: usize = 64; + +/// Trait for copying data from a reference into a mutable buffer. +/// Used by the cache to avoid cloning on cache hits. +pub trait ExtendFromRef { + fn extend_from_ref(&mut self, other: &Self); +} + +impl ExtendFromRef for Vec { + fn extend_from_ref(&mut self, other: &Self) { + self.extend_from_slice(other); + } +} + +#[inline] +fn fx_hash(key: &K) -> u64 { + let mut h = rustc_hash::FxHasher::default(); + key.hash(&mut h); + h.finish() +} + +struct ShardedMap { + shards: Vec>>, + per_shard_capacity: usize, +} + +impl ShardedMap { + fn new(total_capacity: usize) -> Self { + let per_shard = total_capacity.div_ceil(SHARED_CACHE_SHARDS).max(1); + let shards = (0..SHARED_CACHE_SHARDS) + .map(|_| { + RwLock::new(FxHashMap::with_capacity_and_hasher( + per_shard, + Default::default(), + )) + }) + .collect(); + ShardedMap { + shards, + per_shard_capacity: per_shard, + } + } + + #[inline] + fn shard_for(key: &Q) -> usize { + let h = fx_hash(key); + (h >> 48) as usize % SHARED_CACHE_SHARDS + } + + fn get_into(&self, key: &Q, out: &mut V) -> bool + where + K: Borrow, + Q: Hash + Eq + ?Sized, + { + let idx = Self::shard_for(key); + let shard = &self.shards[idx]; + if let Ok(guard) = shard.try_read() { + if let Some(value) = guard.get(key) { + out.extend_from_ref(value); + return true; + } + } + false + } + + fn set(&self, key: K, value: V) { + let idx = Self::shard_for(&key); + let shard = &self.shards[idx]; + if let Ok(guard) = shard.try_read() { + if guard.len() >= self.per_shard_capacity { + return; + } + } else { + return; + } + if let Ok(mut guard) = shard.try_write() { + if guard.len() < self.per_shard_capacity { + guard.insert(key, value); + } + } + } + + fn clear(&self) { + for shard in &self.shards { + if let Ok(mut guard) = shard.write() { + guard.clear(); + } + } + } +} + +// --------------------------------------------------------------------------- +// Public Cache +// --------------------------------------------------------------------------- + +/// Sharded cache for fast concurrent tokenization lookups. +/// +/// Uses 64 shards with per-shard `RwLock` to minimize lock +/// contention across threads. FxHash provides fast, non-cryptographic hashing +/// suited to the small keys used in tokenization caches. pub(crate) struct Cache where K: Eq + Hash + Clone, - V: Clone, + V: ExtendFromRef, { - map: RwLock>, + map: ShardedMap, pub capacity: usize, } -// We dont really care about Cache comparison, so let's make them always equal +impl std::fmt::Debug for Cache +where + K: Eq + Hash + Clone, + V: ExtendFromRef, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Cache") + .field("capacity", &self.capacity) + .finish() + } +} + impl PartialEq for Cache where K: Eq + Hash + Clone, - V: Clone, + V: ExtendFromRef, { fn eq(&self, _other: &Cache) -> bool { true @@ -37,7 +144,7 @@ where impl Default for Cache where K: Eq + Hash + Clone, - V: Clone, + V: ExtendFromRef, { fn default() -> Self { Self::new(DEFAULT_CACHE_CAPACITY) @@ -47,12 +154,14 @@ where impl Cache where K: Eq + Hash + Clone, - V: Clone, + V: ExtendFromRef, { /// Create new `Cache` with the given capacity. pub(crate) fn new(capacity: usize) -> Self { - let map = RwLock::new(AHashMap::with_capacity(capacity)); - Cache { map, capacity } + Cache { + map: ShardedMap::new(capacity), + capacity, + } } /// Create a fresh `Cache` with the same configuration. @@ -62,67 +171,25 @@ where /// Clear the cache. pub(crate) fn clear(&self) { - self.map.write().unwrap().clear(); + self.map.clear(); } - #[allow(dead_code)] - pub(crate) fn get_values<'a, I, Q>(&self, keys_iter: I) -> Option>> - where - I: Iterator, - K: Borrow, - Q: Hash + Eq + ?Sized + 'a, - { - if let Ok(ref mut cache) = self.map.try_read() { - Some(keys_iter.map(|k| cache.get(k).cloned()).collect()) - } else { - None - } - } - - pub(crate) fn get(&self, key: &Q) -> Option + /// Get a value from the cache, extending the output buffer. + /// Returns true if the key was found, false otherwise. + pub(crate) fn get_into(&self, key: &Q, out: &mut V) -> bool where K: Borrow, Q: Hash + Eq + ?Sized, { - if let Ok(ref mut cache) = self.map.try_read() { - cache.get(key).cloned() - } else { - None - } - } - - pub(crate) fn set_values(&self, entries: I) - where - I: IntoIterator, - { - // Before trying to acquire a write lock, we check if we are already at - // capacity with a read handler. - if let Ok(cache) = self.map.try_read() { - if cache.len() >= self.capacity { - // At capacity, so do nothing. - return; - } - } else { - // If we couldn't acquire a read handle then we probably won't be able to acquire - // a write handle one quadrillionth of a second later. - return; - } - - // Not at capacity, so try acquiring a write handle. - if let Ok(mut cache) = self.map.try_write() { - let free = self.capacity - cache.len(); - cache.extend(entries.into_iter().take(free)); - } + self.map.get_into(key, out) } pub(crate) fn set(&self, key: K, value: V) { - self.set_values(std::iter::once((key, value))) + self.map.set(key, value); } pub(crate) fn resize(&mut self, capacity: usize) { self.capacity = capacity; - if let Ok(mut cache) = self.map.try_write() { - cache.shrink_to(capacity); - } + self.map = ShardedMap::new(capacity); } }