1 change: 1 addition & 0 deletions tokenizers/Cargo.toml
@@ -93,6 +93,7 @@ getrandom = { version = "0.3" }
 esaxx-rs = { version = "0.1.10", default-features = false, features = [] }
 monostate = "0.1.12"
 ahash = { version = "0.8.11", features = ["serde"] }
+rustc-hash = "2"
 dary_heap = { version = "0.3.6", features = ["serde"] }
 compact_str = { version = "0.9", features = ["serde"] }

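The new `rustc-hash` dependency provides `FxHashMap`, an API-compatible `HashMap` whose FxHash function is much cheaper than std's default SipHash for small integer keys (at the cost of DoS resistance). A minimal usage sketch, not part of this PR:

```rust
// Minimal sketch (not part of this PR): FxHashMap is a drop-in HashMap
// alias that hashes with FxHash instead of SipHash.
use rustc_hash::FxHashMap;

fn main() {
    let mut ranks: FxHashMap<u64, u32> = FxHashMap::default();
    ranks.insert((17u64 << 32) | 11, 0); // a packed (17, 11) pair
    assert_eq!(ranks.get(&((17u64 << 32) | 11)), Some(&0));
}
```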
75 changes: 75 additions & 0 deletions tokenizers/src/models/bpe/mod.rs
@@ -1,4 +1,5 @@
 //! [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model.
+use rustc_hash::FxHashMap;
 use std::{iter, mem};

 mod model;
@@ -8,6 +9,80 @@ mod word;

 type Pair = (u32, u32);

+/// Packs a `(u32, u32)` pair into a single `u64` for faster hashing.
+#[inline]
+fn pack_pair(pair: &Pair) -> u64 {
+    (pair.0 as u64) << 32 | pair.1 as u64
+}
+
+/// Unpacks a `u64` back into a `(u32, u32)` pair.
+#[inline]
+fn unpack_pair(packed: u64) -> Pair {
+    ((packed >> 32) as u32, packed as u32)
+}
+
+/// A merge-lookup map that packs `(u32, u32)` pair keys into single `u64` values
+/// for faster hashing (single FxHash multiply instead of hashing two fields).
+///
+/// Values are `(rank, new_id)` tuples.
+#[derive(Clone, Debug)]
+pub(crate) struct MergeMap {
+    inner: FxHashMap<u64, (u32, u32)>,
+}
+
+impl MergeMap {
+    #[allow(dead_code)]
+    pub fn new() -> Self {
+        MergeMap {
+            inner: FxHashMap::default(),
+        }
+    }
+
+    pub fn with_capacity(cap: usize) -> Self {
+        MergeMap {
+            inner: FxHashMap::with_capacity_and_hasher(cap, Default::default()),
+        }
+    }
+
+    #[inline]
+    /// Get `(rank, new_id)` for a given `Pair` in the map.
+    pub fn get(&self, pair: &Pair) -> Option<&(u32, u32)> {
+        self.inner.get(&pack_pair(pair))
+    }
+
+    /// Insert `(rank, new_id)` for a given `Pair` in the map.
+    pub fn insert(&mut self, pair: Pair, value: (u32, u32)) -> Option<(u32, u32)> {
+        self.inner.insert(pack_pair(&pair), value)
+    }
+
+    pub fn len(&self) -> usize {
+        self.inner.len()
+    }
+
+    /// Iterate over `(Pair, &(rank, new_id))`.
+    pub fn iter(&self) -> impl Iterator<Item = (Pair, &(u32, u32))> {
+        self.inner.iter().map(|(k, v)| (unpack_pair(*k), v))
+    }
+}
+
+impl PartialEq for MergeMap {
+    fn eq(&self, other: &Self) -> bool {
+        self.inner == other.inner
+    }
+}
+
+impl std::iter::FromIterator<(Pair, (u32, u32))> for MergeMap {
+    fn from_iter<I: IntoIterator<Item = (Pair, (u32, u32))>>(iter: I) -> Self {
+        let iter = iter.into_iter();
+        let (lo, _) = iter.size_hint();
+        let mut map = MergeMap::with_capacity(lo);
+        for (pair, val) in iter {
+            map.insert(pair, val);
+        }
+        map
+    }
+}
+
 /// Errors that can be encountered while using or constructing a `BPE` model.
 #[derive(thiserror::Error, Debug)]
 pub enum Error {
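A standalone sketch (not part of the PR) of the round-trip property the packing relies on: the high 32 bits hold the left token id and the low 32 bits the right one, so the mapping is injective and unpacking always recovers the original pair.

```rust
// Standalone sketch mirroring the private pack_pair/unpack_pair helpers.
fn pack(pair: (u32, u32)) -> u64 {
    ((pair.0 as u64) << 32) | pair.1 as u64
}

fn unpack(packed: u64) -> (u32, u32) {
    ((packed >> 32) as u32, packed as u32)
}

fn main() {
    for pair in [(17, 11), (0, u32::MAX), (u32::MAX, 0)] {
        assert_eq!(unpack(pack(pair)), pair); // lossless round trip
    }
    // Order matters: (a, b) and (b, a) pack to different keys.
    assert_ne!(pack((1, 2)), pack((2, 1)));
}
```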
40 changes: 25 additions & 15 deletions tokenizers/src/models/bpe/model.rs
@@ -1,10 +1,11 @@
-use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word};
+use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, MergeMap, Pair, Word};
 use crate::tokenizer::{Model, Result, Token};
 use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY, MAX_LENGTH};
 use crate::utils::iter::ResultShunt;
 use ahash::AHashMap;
 use serde_json::Value;
 use std::borrow::Cow;
+use std::cell::RefCell;

 use std::collections::HashMap;
 use std::str::from_utf8_unchecked;
@@ -15,9 +16,12 @@ use std::{
     path::{Path, PathBuf},
 };

+thread_local! {
+    static TL_WORD: RefCell<Word> = RefCell::new(Word::with_capacity(64));
+}
+
 pub type Vocab = AHashMap<String, u32>;
 type VocabR = AHashMap<u32, String>;
-pub type MergeMap = AHashMap<Pair, (u32, u32)>;
 pub type Merges = Vec<(String, String)>;

 struct Config {
@@ -495,17 +499,23 @@ impl BPE {
                 )]);
             }
         }
-        if let Some(ref hit) = self.cache.as_ref().and_then(|c| c.get(sequence)) {
-            return Ok(self.word_to_tokens(hit).collect());
-        }
-        let word = self.merge_word(sequence)?;
-        let ret = self.word_to_tokens(&word).collect();
-        if let Some(ref cache) = self.cache {
-            if sequence.len() < MAX_LENGTH {
-                cache.set(sequence.to_owned(), word);
-            }
-        }
-        Ok(ret)
+        TL_WORD.with(|w| {
+            let mut word = w.borrow_mut();
+            word.clear();
+            if let Some(ref cache) = self.cache {
+                if cache.get_into(sequence, &mut word) {
+                    return Ok(self.word_to_tokens(&word).collect());
+                }
+            }
+            let word = self.merge_word(sequence)?;
+            let ret = self.word_to_tokens(&word).collect();
+            if let Some(ref cache) = self.cache {
+                if sequence.len() < MAX_LENGTH {
+                    cache.set(sequence.to_owned(), word);
+                }
+            }
+            Ok(ret)
+        })
     }
 }

@@ -566,12 +576,12 @@ impl Model for BPE {
             .iter()
             .collect();
         let mut merges_file = File::create(&merges_path)?;
-        let mut merges: Vec<(&Pair, &u32)> = self
+        let mut merges: Vec<(Pair, u32)> = self
             .merges
             .iter()
-            .map(|(pair, (rank, _))| (pair, rank))
+            .map(|(pair, (rank, _))| (pair, *rank))
             .collect();
-        merges.sort_unstable_by_key(|k| *k.1);
+        merges.sort_unstable_by_key(|k| k.1);
         merges_file.write_all(b"#version: 0.2\n")?;
         merges_file.write_all(
             &merges
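Two assumptions worth flagging for readers of this hunk: `Cache::get_into` is a new method (its definition is not shown in this diff) that copies a cached `Word` into a caller-supplied buffer instead of cloning one out, and `TL_WORD` gives each thread its own scratch `Word`, so the hot tokenize path stops allocating once the buffer is warm. A self-contained sketch of the same thread-local scratch-buffer pattern, with hypothetical names:

```rust
use std::cell::RefCell;

thread_local! {
    // One scratch buffer per thread, reused across calls so the hot
    // path stops allocating once the buffer has grown large enough.
    static SCRATCH: RefCell<Vec<u32>> = RefCell::new(Vec::with_capacity(64));
}

// Hypothetical worker: fills the scratch buffer, then reads a result out.
fn tokenize_ids(input: &[u32]) -> usize {
    SCRATCH.with(|buf| {
        let mut buf = buf.borrow_mut();
        buf.clear(); // keep the allocation, drop stale contents
        buf.extend_from_slice(input);
        buf.len()
    })
}

fn main() {
    assert_eq!(tokenize_ids(&[1, 2, 3]), 3);
    assert_eq!(tokenize_ids(&[4, 5]), 2); // same allocation, reused
}
```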
6 changes: 3 additions & 3 deletions tokenizers/src/models/bpe/serialization.rs
@@ -24,12 +24,12 @@ impl Serialize for BPE {
         model.serialize_field("ignore_merges", &self.ignore_merges)?;

         // Then the large ones
-        let mut merges: Vec<(&Pair, &u32)> = self
+        let mut merges: Vec<(Pair, u32)> = self
             .merges
             .iter()
-            .map(|(pair, (rank, _))| (pair, rank))
+            .map(|(pair, (rank, _))| (pair, *rank))
             .collect();
-        merges.sort_unstable_by_key(|k| *k.1);
+        merges.sort_unstable_by_key(|k| k.1);
         let merges = merges
             .into_iter()
             .map(|(pair, _)| (self.vocab_r[&pair.0].clone(), self.vocab_r[&pair.1].clone()))
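The switch from `Vec<(&Pair, &u32)>` to `Vec<(Pair, u32)>` follows from the new `MergeMap::iter`: it reconstructs each `Pair` from a packed `u64` key on the fly, so it can only yield owned pairs, and the rank is copied out with `*rank`. A minimal sketch of the resulting sort, not tied to the crate's types:

```rust
// Sketch: rank-ordered merges as owned tuples, as the updated
// save/serialize paths now produce them.
fn main() {
    let mut merges: Vec<((u32, u32), u32)> = vec![((8, 22), 1), ((17, 11), 0)];
    merges.sort_unstable_by_key(|k| k.1); // no deref needed on owned ranks
    assert_eq!(merges, vec![((17, 11), 0), ((8, 22), 1)]);
}
```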
5 changes: 3 additions & 2 deletions tokenizers/src/models/bpe/trainer.rs
@@ -677,7 +677,8 @@ impl Trainer for BpeTrainer {

 #[cfg(test)]
 mod tests {
-    use super::{BpeTrainer, Pair, BPE};
+    use super::{BpeTrainer, BPE};
+    use crate::models::bpe::MergeMap;
     use ahash::AHashMap;
     use compact_str::CompactString;

@@ -744,7 +745,7 @@ mod tests {
         // where 'rank' determines the order in which this merge will be applied during
         // tokenization, and 'id' is the vocab id of the symbol resulting from merging
         // the pair of symbols in the corresponding key.
-        let expected_merges: AHashMap<Pair, (u32, u32)> = [
+        let expected_merges: MergeMap = [
             ((17, 11), (0, 22)), // 'r' + 'e' -> 're'
             ((8, 22), (1, 23)), // 'a' + 're' -> 'are'
             ((13, 18), (2, 24)), // 'i' + 's' -> 'is'
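Since `MergeMap` implements `FromIterator<(Pair, (u32, u32))>`, the test can keep building the expected map with a plain `collect()` (the call itself sits below the truncated part of the diff). A sketch of that usage, assuming `use crate::models::bpe::MergeMap;` is in scope; it only compiles inside the crate because `MergeMap` is `pub(crate)`:

```rust
// Sketch (crate-internal, hypothetical test): collect() goes through
// MergeMap's FromIterator impl, which pre-sizes the map from the
// iterator's size hint.
#[test]
fn merge_map_from_iter() {
    let expected_merges: MergeMap = [
        ((17, 11), (0, 22)), // 'r' + 'e' -> 're'
        ((8, 22), (1, 23)),  // 'a' + 're' -> 'are'
    ]
    .into_iter()
    .collect();
    assert_eq!(expected_merges.len(), 2);
    assert_eq!(expected_merges.get(&(17, 11)), Some(&(0, 22)));
}
```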