1 change: 1 addition & 0 deletions tokenizers/Cargo.toml
@@ -75,6 +75,7 @@ getrandom = { version = "0.3" }
 esaxx-rs = { version = "0.1.10", default-features = false, features=[]}
 monostate = "0.1.12"
 ahash = { version = "0.8.11", features = ["serde"] }
+rustc-hash = "2"
 dary_heap = { version = "0.3.6", features = ["serde"] }
 compact_str = { version = "0.9", features = ["serde"] }
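A note on the new dependency: `rustc-hash` provides `FxHashMap`, a `std::collections::HashMap` alias wired to the FxHash hasher used inside rustc. FxHash is much cheaper than std's default SipHash for small integer keys, but offers no HashDoS resistance, which is presumably acceptable here since merge tables are built from the model's own vocab/merges files rather than attacker-controlled input.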
35 changes: 35 additions & 0 deletions tokenizers/benches/llama3_benchmark.rs
@@ -6,6 +6,7 @@ mod common;
 use common::{iter_bench_encode, iter_bench_encode_batch, iter_bench_train};
 use criterion::{Criterion, Throughput};
 use std::hint::black_box;
+use std::sync::Arc;
 use tokenizers::{
     models::{bpe::BpeTrainerBuilder, TrainerWrapper},
     EncodeInput, Tokenizer,
@@ -43,6 +44,40 @@ pub fn llama3(c: &mut Criterion) {
     group.bench_function("llama3-batch", |b| {
         b.iter_custom(|iters| iter_bench_encode_batch(iters, &tokenizer, &batches))
     });
+    // Concurrent long-context: N threads each encode a different large input (80k chars)
+    let all_lines: Vec<&str> = data.lines().collect();
+    let lines_per_thread = 1000;
+    let tokenizer_arc = Arc::new(tokenizer.clone());
+    for num_threads in [1, 2, 4, 8] {
+        let inputs: Vec<String> = (0..num_threads)
+            .map(|i| {
+                let start = i * lines_per_thread;
+                all_lines[start..start + lines_per_thread].join("\n")
+            })
+            .collect();
+        let total_bytes: usize = inputs.iter().map(|s| s.len()).sum();
+        let tok = tokenizer_arc.clone();
+        group.throughput(Throughput::Bytes(total_bytes as u64));
+        group.bench_function(format!("llama3-concurrent-long-{num_threads}t"), move |b| {
+            b.iter(|| {
+                std::thread::scope(|s| {
+                    let handles: Vec<_> = inputs
+                        .iter()
+                        .map(|input| {
+                            let tok = &tok;
+                            s.spawn(move || {
+                                black_box(tok.encode(black_box(input.as_str()), false).unwrap())
+                            })
+                        })
+                        .collect();
+                    for h in handles {
+                        h.join().unwrap();
+                    }
+                });
+            })
+        });
+    }
+
     let mut trainer: TrainerWrapper = BpeTrainerBuilder::default()
         .show_progress(false)
         .build()
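Two details worth noting about the benchmark above, stated as observations on the code: `std::thread::scope` lets each spawned thread borrow the shared `Arc<Tokenizer>` and its input `String` without `'static` bounds, and `group.throughput(Throughput::Bytes(total_bytes as u64))` tells Criterion how many bytes one iteration processes, so the reported throughput aggregates all threads and the 1/2/4/8-thread variants are directly comparable.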
73 changes: 73 additions & 0 deletions tokenizers/src/models/bpe/mod.rs
@@ -1,4 +1,5 @@
 //! [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model.
+use rustc_hash::FxHashMap;
 use std::{iter, mem};
 
 mod model;
@@ -8,6 +9,78 @@ mod word;
 
 type Pair = (u32, u32);
 
+/// Packs a `(u32, u32)` pair into a single `u64` for faster hashing.
+#[inline]
+fn pack_pair(pair: &Pair) -> u64 {
+    (pair.0 as u64) << 32 | pair.1 as u64
+}
+
+/// Unpacks a `u64` back into a `(u32, u32)` pair.
+#[inline]
+fn unpack_pair(packed: u64) -> Pair {
+    ((packed >> 32) as u32, packed as u32)
+}
+
+/// A merge-lookup map that packs `(u32, u32)` pair keys into single `u64` values
+/// for faster hashing (single FxHash multiply instead of hashing two fields).
+///
+/// Values are `(rank, new_id)` tuples.
+#[derive(Clone, Debug)]
+pub(crate) struct MergeMap {
+    inner: FxHashMap<u64, (u32, u32)>,
+}
+
+impl MergeMap {
+    #[allow(dead_code)]
+    pub fn new() -> Self {
+        MergeMap {
+            inner: FxHashMap::default(),
+        }
+    }
+
+    pub fn with_capacity(cap: usize) -> Self {
+        MergeMap {
+            inner: FxHashMap::with_capacity_and_hasher(cap, Default::default()),
+        }
+    }
+
+    #[inline]
+    pub fn get(&self, pair: &Pair) -> Option<&(u32, u32)> {
+        self.inner.get(&pack_pair(pair))
+    }
+
+    pub fn insert(&mut self, pair: Pair, value: (u32, u32)) -> Option<(u32, u32)> {
+        self.inner.insert(pack_pair(&pair), value)
+    }
+
+    pub fn len(&self) -> usize {
+        self.inner.len()
+    }
+
+    /// Iterate over `(Pair, &(rank, new_id))`.
+    pub fn iter(&self) -> impl Iterator<Item = (Pair, &(u32, u32))> {
+        self.inner.iter().map(|(k, v)| (unpack_pair(*k), v))
+    }
+}
+
+impl PartialEq for MergeMap {
+    fn eq(&self, other: &Self) -> bool {
+        self.inner == other.inner
+    }
+}
+
+impl std::iter::FromIterator<(Pair, (u32, u32))> for MergeMap {
+    fn from_iter<I: IntoIterator<Item = (Pair, (u32, u32))>>(iter: I) -> Self {
+        let iter = iter.into_iter();
+        let (lo, _) = iter.size_hint();
+        let mut map = MergeMap::with_capacity(lo);
+        for (pair, val) in iter {
+            map.insert(pair, val);
+        }
+        map
+    }
+}
+
 /// Errors that can be encountered while using or constructing a `BPE` model.
 #[derive(thiserror::Error, Debug)]
 pub enum Error {
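For intuition about the new type, a reviewer-style sketch (not part of the PR): the packing is lossless, and the `FromIterator` impl lets a `MergeMap` be built with `collect()`. A minimal hypothetical test, assuming it lives in this same module so the private helpers are in scope; the module and test names are made up for illustration.

#[cfg(test)]
mod pack_tests {
    use super::{pack_pair, unpack_pair, MergeMap};

    #[test]
    fn pack_round_trip() {
        // The first element occupies the high 32 bits, the second the low 32 bits.
        assert_eq!(pack_pair(&(3, 7)), 0x0000_0003_0000_0007);
        let pair = (u32::MAX, 42);
        assert_eq!(unpack_pair(pack_pair(&pair)), pair);
    }

    #[test]
    fn merge_map_from_iterator() {
        // (pair, (rank, new_id)) tuples collected via the FromIterator impl.
        let merges: MergeMap = [((17, 11), (0, 22)), ((8, 22), (1, 23))]
            .into_iter()
            .collect();
        assert_eq!(merges.len(), 2);
        assert_eq!(merges.get(&(17, 11)), Some(&(0, 22)));
        // Pair order matters: (11, 17) packs to a different u64 key.
        assert_eq!(merges.get(&(11, 17)), None);
    }
}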
9 changes: 4 additions & 5 deletions tokenizers/src/models/bpe/model.rs
@@ -1,4 +1,4 @@
-use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word};
+use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, MergeMap, Pair, Word};
 use crate::tokenizer::{Model, Result, Token};
 use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY, MAX_LENGTH};
 use crate::utils::iter::ResultShunt;
@@ -16,7 +16,6 @@ use std::{
 
 pub type Vocab = AHashMap<String, u32>;
 type VocabR = AHashMap<u32, String>;
-pub type MergeMap = AHashMap<Pair, (u32, u32)>;
 pub type Merges = Vec<(String, String)>;
 
 struct Config {
@@ -553,12 +552,12 @@ impl Model for BPE {
             .iter()
             .collect();
         let mut merges_file = File::create(&merges_path)?;
-        let mut merges: Vec<(&Pair, &u32)> = self
+        let mut merges: Vec<(Pair, u32)> = self
             .merges
             .iter()
-            .map(|(pair, (rank, _))| (pair, rank))
+            .map(|(pair, (rank, _))| (pair, *rank))
             .collect();
-        merges.sort_unstable_by_key(|k| *k.1);
+        merges.sort_unstable_by_key(|k| k.1);
         merges_file.write_all(b"#version: 0.2\n")?;
         merges_file.write_all(
             &merges
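The ownership changes above follow directly from the new `MergeMap::iter`: it unpacks each `u64` key on the fly, so it yields an owned `Pair` per entry rather than a `&Pair` reference, and the rank is copied out with `*rank`. The same adjustment appears in `serialization.rs` below.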
6 changes: 3 additions & 3 deletions tokenizers/src/models/bpe/serialization.rs
@@ -24,12 +24,12 @@ impl Serialize for BPE {
         model.serialize_field("ignore_merges", &self.ignore_merges)?;
 
         // Then the large ones
-        let mut merges: Vec<(&Pair, &u32)> = self
+        let mut merges: Vec<(Pair, u32)> = self
             .merges
             .iter()
-            .map(|(pair, (rank, _))| (pair, rank))
+            .map(|(pair, (rank, _))| (pair, *rank))
             .collect();
-        merges.sort_unstable_by_key(|k| *k.1);
+        merges.sort_unstable_by_key(|k| k.1);
         let merges = merges
             .into_iter()
             .map(|(pair, _)| (self.vocab_r[&pair.0].clone(), self.vocab_r[&pair.1].clone()))
5 changes: 3 additions & 2 deletions tokenizers/src/models/bpe/trainer.rs
@@ -640,7 +640,8 @@ impl Trainer for BpeTrainer {
 
 #[cfg(test)]
 mod tests {
-    use super::{BpeTrainer, Pair, BPE};
+    use super::{BpeTrainer, BPE};
+    use crate::models::bpe::MergeMap;
     use ahash::AHashMap;
     use compact_str::CompactString;
 
@@ -707,7 +708,7 @@ mod tests {
         // where 'rank' determines the order in which this merge will be applied during
         // tokenization, and 'id' is the vocab id of the symbol resulting from merging
         // the pair of symbols in the corresponding key.
-        let expected_merges: AHashMap<Pair, (u32, u32)> = [
+        let expected_merges: MergeMap = [
            ((17, 11), (0, 22)), // 'r' + 'e' -> 're'
            ((8, 22), (1, 23)), // 'a' + 're' -> 'are'
            ((13, 18), (2, 24)), // 'i' + 's' -> 'is'