diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 8e282fba28..288b04aefd 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -20,6 +20,7 @@ use std::{
 use serde::de::DeserializeOwned;
 use serde::{Deserialize, Serialize};
 
+use crate::utils::batch::{BatchWorkQueue, ResultVec, TakeVec};
 use crate::utils::iter::ResultShunt;
 use crate::utils::parallelism::*;
 use crate::utils::progress::{ProgressBar, ProgressStyle};
@@ -1300,7 +1301,11 @@ where
     PP: PostProcessor + Send + Sync,
     D: Decoder + Send + Sync,
 {
-    /// Encode all the sentences in parallel, using multiple threads
+    /// Encode all the sentences in parallel, using multiple threads.
+    ///
+    /// Uses a lock-free work queue with cache-line-sized windows instead of
+    /// rayon's `bridge_producer_consumer`, eliminating its synchronization
+    /// overhead at higher thread counts.
     pub fn encode_batch<'s, E>(
         &self,
         inputs: Vec<E>,
@@ -1309,13 +1314,10 @@ where
         E: Into<EncodeInput<'s>> + Send,
     {
-        let mut encodings = inputs
-            .into_maybe_par_iter()
-            .map(|input| self.encode(input, add_special_tokens))
-            .collect::<Result<Vec<Encoding>>>()?;
+        let mut encodings =
+            self.run_batch(inputs, |this, input| this.encode(input, add_special_tokens))?;
 
         if let Some(params) = &self.padding {
-            // We do the padding here to make sure we handle the batch padding
             pad_encodings(&mut encodings, params)?;
         }
 
@@ -1332,20 +1334,22 @@ where
         E: Into<EncodeInput<'s>> + Send,
     {
-        let mut encodings = inputs
-            .into_maybe_par_iter()
-            .map(|input| self.encode_char_offsets(input, add_special_tokens))
-            .collect::<Result<Vec<Encoding>>>()?;
+        let mut encodings = self.run_batch(inputs, |this, input| {
+            this.encode_char_offsets(input, add_special_tokens)
+        })?;
 
         if let Some(params) = &self.padding {
-            // We do the padding here to make sure we handle the batch padding
             pad_encodings(&mut encodings, params)?;
         }
 
         Ok(encodings)
     }
 
-    /// Encode all the sentences in parallel, using multiple threads
+    /// Encode all the sentences in parallel, using multiple threads.
+    ///
+    /// Uses a lock-free work queue with cache-line-sized windows instead of
+    /// rayon's `bridge_producer_consumer`, eliminating its synchronization
+    /// overhead at higher thread counts.
     pub fn encode_batch_fast<'s, E>(
         &self,
         inputs: Vec<E>,
@@ -1354,19 +1358,79 @@ where
         E: Into<EncodeInput<'s>> + Send,
     {
-        let mut encodings = inputs
-            .into_maybe_par_iter()
-            .map(|input| self.encode_fast(input, add_special_tokens))
-            .collect::<Result<Vec<Encoding>>>()?;
+        let mut encodings = self.run_batch(inputs, |this, input| {
+            this.encode_fast(input, add_special_tokens)
+        })?;
 
         if let Some(params) = &self.padding {
-            // We do the padding here to make sure we handle the batch padding
             pad_encodings(&mut encodings, params)?;
         }
 
         Ok(encodings)
     }
 
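For intuition, here is a minimal standalone sketch of the claiming pattern that the shared `run_batch` implementation below uses, written against `std` only; `for_each_window` and its parameters are illustrative names, not part of this patch:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Workers pull fixed-size windows of indices from one shared counter,
// instead of handing rayon an iterator to bridge across threads.
fn for_each_window<T: Sync>(
    items: &[T],
    num_threads: usize,
    window: usize,
    f: impl Fn(&T) + Sync,
) {
    let next = AtomicUsize::new(0);
    std::thread::scope(|s| {
        for _ in 0..num_threads {
            s.spawn(|| loop {
                // The only shared-state operation on the hot path.
                let start = next.fetch_add(window, Ordering::Relaxed);
                if start >= items.len() {
                    break;
                }
                for item in &items[start..(start + window).min(items.len())] {
                    f(item);
                }
            });
        }
    });
}

fn main() {
    let data: Vec<u32> = (0..100).collect();
    for_each_window(&data, 4, 8, |x| {
        // Stand-in for per-item encode work.
        std::hint::black_box(x * 2);
    });
}
```

Each `fetch_add` claims a whole window, so the counter is touched once per `window` items rather than once per item.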
+    fn run_batch<'s, E, F>(&self, inputs: Vec<E>, encode_fn: F) -> Result<Vec<Encoding>>
+    where
+        E: Into<EncodeInput<'s>> + Send,
+        F: Fn(&Self, EncodeInput<'s>) -> Result<Encoding> + Sync,
+    {
+        let n = inputs.len();
+        if n == 0 {
+            return Ok(vec![]);
+        }
+
+        let parallelism = get_parallelism();
+        let num_threads = if parallelism {
+            current_num_threads().min(n)
+        } else {
+            1
+        };
+
+        if num_threads <= 1 {
+            return inputs
+                .into_iter()
+                .map(|input| encode_fn(self, input.into()))
+                .collect::<Result<Vec<Encoding>>>();
+        }
+
+        // Lock-free batch distribution: atomic counter hands out
+        // dynamically-sized windows of item indices to worker threads.
+        let inputs = TakeVec::new(
+            inputs
+                .into_iter()
+                .map(|e| e.into())
+                .collect::<Vec<EncodeInput<'s>>>(),
+        );
+        let results: ResultVec<Result<Encoding>> = ResultVec::new(n);
+        let queue = BatchWorkQueue::new(n, num_threads);
+
+        rayon::scope(|s| {
+            for _ in 0..num_threads {
+                s.spawn(|_| {
+                    while let Some((start, end)) = queue.claim_window() {
+                        for i in start..end {
+                            let input = inputs.take(i);
+                            results.set(i, encode_fn(self, input));
+                        }
+                    }
+                });
+            }
+        });
+
+        results
+            .into_vec()
+            .into_iter()
+            .collect::<Result<Vec<Encoding>>>()
+    }
+
     /// Decode all sentences in parallel
     pub fn decode_batch(
         &self,
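The final `collect` in `run_batch` folds a `Vec<Result<Encoding>>` into a `Result<Vec<Encoding>>`, yielding the first error in index order. A small self-contained illustration of that semantics, with hypothetical values:

```rust
fn main() {
    // Collecting an iterator of `Result`s short-circuits into the
    // first `Err`, otherwise it produces the full `Ok(Vec<_>)`.
    let ok: Vec<Result<u32, String>> = vec![Ok(1), Ok(2)];
    assert_eq!(ok.into_iter().collect::<Result<Vec<_>, _>>(), Ok(vec![1, 2]));

    let bad: Vec<Result<u32, String>> = vec![Ok(1), Err("boom".into()), Ok(3)];
    assert_eq!(
        bad.into_iter().collect::<Result<Vec<_>, _>>(),
        Err("boom".to_string())
    );
}
```

Note that under this scheme every slot is still encoded and written before errors are inspected; an error only surfaces at this final collect, rather than cancelling in-flight work.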
diff --git a/tokenizers/src/utils/batch.rs b/tokenizers/src/utils/batch.rs
new file mode 100644
index 0000000000..e0a3c0a6b7
--- /dev/null
+++ b/tokenizers/src/utils/batch.rs
@@ -0,0 +1,299 @@
+// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! Lock-free batch work distribution with dynamic window sizing.
+//!
+//! Replaces rayon's parallel iteration for batch encode with a simpler
+//! mechanism: a single atomic counter hands out contiguous windows of
+//! item indices to worker threads running on rayon's persistent thread
+//! pool. The only cross-thread synchronization on the hot path is the
+//! `AtomicUsize::fetch_add` that claims each window.
+//!
+//! ## Cache-line / loop-tiling rationale
+//!
+//! Shared-memory parallel loops are bottlenecked by the cache coherence
+//! protocol when two cores alternate writes to the same cache line: the
+//! line "ping-pongs" between their private L1d caches, each transfer
+//! costing dozens of cycles. To avoid that, every line should be filled
+//! by one producer core, drained (or no longer needed), and only then
+//! touched by a different core. This is the cache-aware equivalent of
+//! loop tiling / blocking.
+//!
+//! The work queue enforces this in three ways:
+//!
+//! 1. The counter itself lives on its own 64-byte cache line
+//!    (`#[repr(C, align(64))]` on `AlignedCounter`). A worker's
+//!    `fetch_add` does not evict any neighbouring data, and reads of the
+//!    counter do not pull input or result payloads into the core's L1d.
+//!
+//! 2. Each window is a contiguous run of `window_size` indices, so every
+//!    worker owns a run of adjacent slots for the duration of one
+//!    window. With `MAX_WINDOW_SIZE = 8`, a window covers roughly
+//!    `8 * sizeof(slot)` bytes -- for `Option<EncodeInput>` (~48 B) that
+//!    is ~6 cache lines; for `Option<Result<Encoding>>` (multi-line per
+//!    slot) it is even more. So within one window, a worker writes
+//!    several whole cache lines before any other worker comes near them.
+//!
+//! 3. Each slot has its own `UnsafeCell`
+//!    (`Vec<UnsafeCell<Option<T>>>`). `UnsafeCell<T>` is
+//!    `#[repr(transparent)]` so the heap layout is identical to a plain
+//!    `Vec<Option<T>>` (no padding, no indirection), but concurrent
+//!    accesses to different indices never materialise a shared `&mut`
+//!    reference to the enclosing `Vec` (which would be UB, regardless of
+//!    which element each access ultimately reached).
+//!
+//! At window boundaries a single cache line can be shared between two
+//! successive windows when the slot size does not divide 64 bytes. That
+//! is a *sequential* handoff (window N finishes writes; window N+1 then
+//! reads/writes), not a concurrent ping-pong.
+//!
+//! ## Window sizing
+//!
+//! `window_size = ceil(total / (num_threads * WINDOWS_PER_THREAD))`,
+//! clamped to `[1, MAX_WINDOW_SIZE]`.
+//!
+//! - `WINDOWS_PER_THREAD = 4` keeps several windows per thread so a
+//!   slow worker on its last item does not stall the whole batch.
+//! - `MAX_WINDOW_SIZE = 8` caps per-claim atomic latency and keeps the
+//!   per-window memory footprint small enough to fit comfortably in L1d.
+//!
+//! Example: 100 items / 16 threads yields window_size = 2 (50 windows);
+//! 10000 items / 16 threads yields window_size = 8 (1250 windows).
+
+use std::cell::UnsafeCell;
+use std::sync::atomic::{AtomicUsize, Ordering};
+
+/// Minimum number of windows each thread should get for load balancing.
+const WINDOWS_PER_THREAD: usize = 4;
+
+/// Maximum window size (items per atomic claim). Larger values reduce
+/// atomic contention but worsen tail latency from uneven last windows.
+const MAX_WINDOW_SIZE: usize = 8;
+
+/// Cache-line-aligned atomic counter.
+/// Ensures the counter does not share a cache line with any other data.
+#[repr(C, align(64))]
+struct AlignedCounter(AtomicUsize);
+
+/// Lock-free work distributor.
+///
+/// Workers atomically claim non-overlapping windows of item indices.
+/// The window size is chosen dynamically based on `total` and
+/// `num_threads` so that every thread gets several windows of work.
+/// The counter is on its own cache line so claiming work does not
+/// contend with result writes.
+pub(crate) struct BatchWorkQueue {
+    next: AlignedCounter,
+    total: usize,
+    window_size: usize,
+}
+
+impl BatchWorkQueue {
+    /// Create a new queue distributing `total` items across `num_threads`.
+    ///
+    /// The window size is chosen to give each thread at least
+    /// `WINDOWS_PER_THREAD` windows, capped at `MAX_WINDOW_SIZE`.
+    pub(crate) fn new(total: usize, num_threads: usize) -> Self {
+        let target_windows = num_threads.saturating_mul(WINDOWS_PER_THREAD).max(1);
+        let window_size = total.div_ceil(target_windows).clamp(1, MAX_WINDOW_SIZE);
+        Self {
+            next: AlignedCounter(AtomicUsize::new(0)),
+            total,
+            window_size,
+        }
+    }
+
+    /// Claim the next window of work items.
+    /// Returns a half-open `Some((start, end))` range, or `None` when all
+    /// items have been claimed.
+    pub(crate) fn claim_window(&self) -> Option<(usize, usize)> {
+        let start = self.next.0.fetch_add(self.window_size, Ordering::Relaxed);
+        if start >= self.total {
+            return None;
+        }
+        Some((start, (start + self.window_size).min(self.total)))
+    }
+}
+
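A standalone mirror of `claim_window`'s contract under assumed parameters (`total = 10`, `window = 4`), showing the clamped final window and exhaustion behaviour; `claim` here is a hypothetical free function, not the type's method:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

fn claim(next: &AtomicUsize, total: usize, window: usize) -> Option<(usize, usize)> {
    let start = next.fetch_add(window, Ordering::Relaxed);
    if start >= total {
        return None; // the counter may run past `total`; that is harmless
    }
    Some((start, (start + window).min(total)))
}

fn main() {
    let next = AtomicUsize::new(0);
    assert_eq!(claim(&next, 10, 4), Some((0, 4)));
    assert_eq!(claim(&next, 10, 4), Some((4, 8)));
    assert_eq!(claim(&next, 10, 4), Some((8, 10))); // clamped to total
    assert_eq!(claim(&next, 10, 4), None);
    assert_eq!(claim(&next, 10, 4), None); // stays exhausted
}
```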
The only +/// difference is that `self.0[i].get()` gives a raw `*mut Option` +/// pointing straight at slot `i`, without ever materialising a +/// `&mut Vec>` (which would alias the enclosing container and +/// be UB when two threads touch any distinct indices concurrently). +pub(crate) struct TakeVec(Vec>>); + +// SAFETY: callers guarantee each index is accessed by at most one thread; +// `take` produces a raw pointer to a single slot's `UnsafeCell` without +// aliasing the surrounding `Vec`. +unsafe impl Sync for TakeVec {} + +impl TakeVec { + /// Wrap a `Vec` so items can be taken by index. + pub(crate) fn new(items: Vec) -> Self { + Self( + items + .into_iter() + .map(|t| UnsafeCell::new(Some(t))) + .collect(), + ) + } + + /// Take the item at `index`, leaving `None` in its place. + /// Panics if the item was already taken. + pub(crate) fn take(&self, index: usize) -> T { + // SAFETY: the `BatchWorkQueue` guarantees that each `index` is passed + // to `take` by at most one thread. `self.0[index].get()` returns a + // raw pointer to that slot's `Option`; reborrowing it as `&mut` + // does not alias any sibling slot's data. + unsafe { + (*self.0[index].get()) + .take() + .expect("batch item already taken") + } + } +} + +/// A `Vec>` where each slot is written exactly once from any +/// thread. +/// +/// The `BatchWorkQueue` guarantees non-overlapping index access. +/// +/// Layout: same note as `TakeVec`. Each slot is a +/// `UnsafeCell>` (`#[repr(transparent)]` over `Option`), so +/// the heap layout is byte-identical to a plain `Vec>` +/// and `self.0[i].get()` yields a raw `*mut Option` to slot `i` +/// without materialising a `&mut Vec>`. +pub(crate) struct ResultVec(Vec>>); + +// SAFETY: callers guarantee each index is written by at most one thread; +// `set` produces a raw pointer to a single slot's `UnsafeCell` without +// aliasing the surrounding `Vec`. +unsafe impl Sync for ResultVec {} + +impl ResultVec { + /// Allocate `len` empty result slots. + pub(crate) fn new(len: usize) -> Self { + Self((0..len).map(|_| UnsafeCell::new(None)).collect()) + } + + /// Write a result to the slot at `index`. + pub(crate) fn set(&self, index: usize, value: T) { + // SAFETY: the `BatchWorkQueue` guarantees that each `index` is passed + // to `set` by at most one thread, so no other reference to this + // slot's `Option` exists concurrently. + unsafe { + *self.0[index].get() = Some(value); + } + } + + /// Consume self and return the results in order. + /// Panics if any slot was not written. + pub(crate) fn into_vec(self) -> Vec { + self.0 + .into_iter() + .map(|cell| cell.into_inner().expect("result slot was never written")) + .collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_batch_work_queue_single_thread() { + // 20 items, 1 thread => target 4 windows => window_size = 5. + let queue = BatchWorkQueue::new(20, 1); + let mut ranges = Vec::new(); + while let Some(range) = queue.claim_window() { + ranges.push(range); + } + assert_eq!(ranges.len(), 4); + assert_eq!(ranges[0], (0, 5)); + assert_eq!(ranges[1], (5, 10)); + assert_eq!(ranges[2], (10, 15)); + assert_eq!(ranges[3], (15, 20)); + } + + #[test] + fn test_batch_work_queue_many_threads() { + // 100 items, 16 threads => target 64 windows => window_size = 2. 
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_batch_work_queue_single_thread() {
+        // 20 items, 1 thread => target 4 windows => window_size = 5.
+        let queue = BatchWorkQueue::new(20, 1);
+        let mut ranges = Vec::new();
+        while let Some(range) = queue.claim_window() {
+            ranges.push(range);
+        }
+        assert_eq!(ranges.len(), 4);
+        assert_eq!(ranges[0], (0, 5));
+        assert_eq!(ranges[1], (5, 10));
+        assert_eq!(ranges[2], (10, 15));
+        assert_eq!(ranges[3], (15, 20));
+    }
+
+    #[test]
+    fn test_batch_work_queue_many_threads() {
+        // 100 items, 16 threads => target 64 windows => window_size = 2.
+        let queue = BatchWorkQueue::new(100, 16);
+        let mut ranges = Vec::new();
+        while let Some(range) = queue.claim_window() {
+            ranges.push(range);
+        }
+        assert_eq!(ranges.len(), 50);
+        assert_eq!(ranges[0], (0, 2));
+        assert_eq!(ranges[49], (98, 100));
+    }
+
+    #[test]
+    fn test_batch_work_queue_window_capped() {
+        // 10000 items, 4 threads => target 16 windows => window_size = 625,
+        // but capped at MAX_WINDOW_SIZE (8).
+        let queue = BatchWorkQueue::new(10000, 4);
+        let mut count = 0;
+        while queue.claim_window().is_some() {
+            count += 1;
+        }
+        // 10000 / 8 = 1250 windows.
+        assert_eq!(count, 1250);
+    }
+
+    #[test]
+    fn test_batch_work_queue_empty() {
+        let queue = BatchWorkQueue::new(0, 4);
+        assert!(queue.claim_window().is_none());
+    }
+
+    #[test]
+    fn test_take_vec() {
+        let tv = TakeVec::new(vec![10, 20, 30]);
+        assert_eq!(tv.take(1), 20);
+        assert_eq!(tv.take(0), 10);
+        assert_eq!(tv.take(2), 30);
+    }
+
+    #[test]
+    fn test_result_vec() {
+        let rv = ResultVec::<usize>::new(3);
+        rv.set(2, 30);
+        rv.set(0, 10);
+        rv.set(1, 20);
+        assert_eq!(rv.into_vec(), vec![10, 20, 30]);
+    }
+
+    #[test]
+    fn test_parallel_distribution() {
+        let n = 100;
+        let num_threads = 4;
+        let queue = BatchWorkQueue::new(n, num_threads);
+        let results = ResultVec::new(n);
+
+        std::thread::scope(|s| {
+            for _ in 0..num_threads {
+                s.spawn(|| {
+                    while let Some((start, end)) = queue.claim_window() {
+                        for i in start..end {
+                            results.set(i, i * 2);
+                        }
+                    }
+                });
+            }
+        });
+
+        let v = results.into_vec();
+        for (i, &item) in v.iter().enumerate() {
+            assert_eq!(item, i * 2);
+        }
+    }
+}
diff --git a/tokenizers/src/utils/mod.rs b/tokenizers/src/utils/mod.rs
index c9450b3222..252a466e64 100644
--- a/tokenizers/src/utils/mod.rs
+++ b/tokenizers/src/utils/mod.rs
@@ -1,3 +1,4 @@
+pub(crate) mod batch;
 pub(crate) mod cache;
 #[cfg(feature = "http")]
 pub(crate) mod from_pretrained;
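The public API is unchanged by this patch, so existing batch callers pick up the new code path transparently. A typical call, with an illustrative `tokenizer.json` path:

```rust
use tokenizers::Tokenizer;

fn main() -> tokenizers::Result<()> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    // Internally routed through `run_batch` and the lock-free queue.
    let encodings = tokenizer.encode_batch(
        vec!["Hello there!", "How are you?"],
        true, // add_special_tokens
    )?;
    assert_eq!(encodings.len(), 2);
    Ok(())
}
```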