diff --git a/bindings/python/py_src/tokenizers/__init__.pyi b/bindings/python/py_src/tokenizers/__init__.pyi
index 44f19b8a4..da5cf6cc2 100644
--- a/bindings/python/py_src/tokenizers/__init__.pyi
+++ b/bindings/python/py_src/tokenizers/__init__.pyi
@@ -1614,6 +1614,60 @@ class Tokenizer:
             :class:`~tokenizers.Encoding`: The final post-processed encoding
         """
         pass
+    def post_process_tokens(
+        self,
+        /,
+        tokens: list[str],
+        pair: list[str] | None = None,
+        add_special_tokens: bool = True,
+    ) -> list[str]:
+        """
+        Post-process a list of tokens (and optionally a pair) and return the processed tokens.
+
+        This is a simplified interface that only handles the token strings, without the full
+        Encoding information. Useful for step-by-step tokenization.
+
+        Args:
+            tokens (:obj:`List[str]`):
+                The main sequence of tokens
+
+            pair (:obj:`List[str]`, `optional`):
+                An optional pair sequence of tokens
+
+            add_special_tokens (:obj:`bool`, defaults to :obj:`True`):
+                Whether to add special tokens
+
+        Returns:
+            :obj:`List[str]`: A list of tokens with special tokens added according to the post-processor
+        """
+        ...
+    def post_process_ids(
+        self,
+        /,
+        ids: list[int],
+        pair: list[int] | None = None,
+        add_special_tokens: bool = True,
+    ) -> list[int]:
+        """
+        Post-process a list of token IDs (and optionally a pair) and return the processed IDs.
+
+        This is a simplified interface that only handles the token IDs, without the full
+        Encoding information. Useful for step-by-step tokenization.
+
+        Args:
+            ids (:obj:`List[int]`):
+                The main sequence of token IDs
+
+            pair (:obj:`List[int]`, `optional`):
+                An optional pair sequence of token IDs
+
+            add_special_tokens (:obj:`bool`, defaults to :obj:`True`):
+                Whether to add special tokens
+
+        Returns:
+            :obj:`List[int]`: A list of token IDs with special tokens added according to the post-processor
+        """
+        ...
     @property
     def post_processor(self):
diff --git a/bindings/python/py_src/tokenizers/processors/__init__.pyi b/bindings/python/py_src/tokenizers/processors/__init__.pyi
index 0d49520c6..c38f4f1f8 100644
--- a/bindings/python/py_src/tokenizers/processors/__init__.pyi
+++ b/bindings/python/py_src/tokenizers/processors/__init__.pyi
@@ -117,6 +117,60 @@ class BertProcessing(PostProcessor):
             :class:`~tokenizers.Encoding`: The final encoding
         """
         pass
+    def process_tokens(
+        self,
+        /,
+        tokens: list[str],
+        pair: list[str] | None = None,
+        add_special_tokens: bool = True,
+    ) -> list[str]:
+        """
+        Process a list of tokens (and optionally a pair) and return the processed tokens.
+
+        This is a simplified interface that only handles the token strings, without the full
+        Encoding information. Useful for step-by-step tokenization.
+
+        Args:
+            tokens (:obj:`List[str]`):
+                The main sequence of tokens
+
+            pair (:obj:`List[str]`, `optional`):
+                An optional pair sequence of tokens
+
+            add_special_tokens (:obj:`bool`, defaults to :obj:`True`):
+                Whether to add special tokens
+
+        Returns:
+            :obj:`List[str]`: A list of tokens with special tokens added
+        """
+        ...
+    def process_ids(
+        self,
+        /,
+        ids: list[int],
+        pair: list[int] | None = None,
+        add_special_tokens: bool = True,
+    ) -> list[int]:
+        """
+        Process a list of token IDs (and optionally a pair) and return the processed IDs.
+
+        This is a simplified interface that only handles the token IDs, without the full
+        Encoding information. Useful for step-by-step tokenization.
+
+        Args:
+            ids (:obj:`List[int]`):
+                The main sequence of token IDs
+
+            pair (:obj:`List[int]`, `optional`):
+                An optional pair sequence of token IDs
+
+            add_special_tokens (:obj:`bool`, defaults to :obj:`True`):
+                Whether to add special tokens
+
+        Returns:
+            :obj:`List[int]`: A list of token IDs with special tokens added
+        """
+        ...
     @property
     def sep(self):
diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs
index 3973bd795..9dbab0739 100644
--- a/bindings/python/src/processors.rs
+++ b/bindings/python/src/processors.rs
@@ -94,6 +94,26 @@ impl PostProcessor for PyPostProcessor {
         self.processor
             .process_encodings(encodings, add_special_tokens)
     }
+
+    fn process_tokens(
+        &self,
+        tokens: Vec<String>,
+        pair_tokens: Option<Vec<String>>,
+        add_special_tokens: bool,
+    ) -> tk::Result<Vec<String>> {
+        self.processor
+            .process_tokens(tokens, pair_tokens, add_special_tokens)
+    }
+
+    fn process_ids(
+        &self,
+        ids: Vec<u32>,
+        pair_ids: Option<Vec<u32>>,
+        add_special_tokens: bool,
+    ) -> tk::Result<Vec<u32>> {
+        self.processor
+            .process_ids(ids, pair_ids, add_special_tokens)
+    }
 }

 #[pymethods]
@@ -165,6 +185,66 @@ impl PyPostProcessor {
         Ok(final_encoding.into())
     }

+    /// Process a list of tokens (and optionally a pair) and return the processed tokens.
+    ///
+    /// This is a simplified interface that only handles the token strings, without the full
+    /// Encoding information. Useful for step-by-step tokenization.
+    ///
+    /// Args:
+    ///     tokens (:obj:`List[str]`):
+    ///         The main sequence of tokens
+    ///
+    ///     pair (:obj:`List[str]`, `optional`):
+    ///         An optional pair sequence of tokens
+    ///
+    ///     add_special_tokens (:obj:`bool`, defaults to :obj:`True`):
+    ///         Whether to add special tokens
+    ///
+    /// Returns:
+    ///     :obj:`List[str]`: A list of tokens with special tokens added
+    #[pyo3(signature = (tokens, pair = None, add_special_tokens = true))]
+    #[pyo3(text_signature = "(self, tokens, pair=None, add_special_tokens=True)")]
+    fn process_tokens(
+        &self,
+        tokens: Vec<String>,
+        pair: Option<Vec<String>>,
+        add_special_tokens: bool,
+    ) -> PyResult<Vec<String>> {
+        ToPyResult(
+            self.processor
+                .process_tokens(tokens, pair, add_special_tokens),
+        )
+        .into()
+    }
+
+    /// Process a list of token IDs (and optionally a pair) and return the processed IDs.
+    ///
+    /// This is a simplified interface that only handles the token IDs, without the full
+    /// Encoding information. Useful for step-by-step tokenization.
+    ///
+    /// Args:
+    ///     ids (:obj:`List[int]`):
+    ///         The main sequence of token IDs
+    ///
+    ///     pair (:obj:`List[int]`, `optional`):
+    ///         An optional pair sequence of token IDs
+    ///
+    ///     add_special_tokens (:obj:`bool`, defaults to :obj:`True`):
+    ///         Whether to add special tokens
+    ///
+    /// Returns:
+    ///     :obj:`List[int]`: A list of token IDs with special tokens added
+    #[pyo3(signature = (ids, pair = None, add_special_tokens = true))]
+    #[pyo3(text_signature = "(self, ids, pair=None, add_special_tokens=True)")]
+    fn process_ids(
+        &self,
+        ids: Vec<u32>,
+        pair: Option<Vec<u32>>,
+        add_special_tokens: bool,
+    ) -> PyResult<Vec<u32>> {
+        ToPyResult(self.processor.process_ids(ids, pair, add_special_tokens)).into()
+    }
+
     fn __repr__(&self) -> PyResult<String> {
         crate::utils::serde_pyo3::repr(self)
             .map_err(|e| exceptions::PyException::new_err(e.to_string()))
@@ -258,6 +338,56 @@ impl PostProcessor for PyPostProcessorTypeWrapper {
             },
         }
     }
+
+    fn process_tokens(
+        &self,
+        mut tokens: Vec<String>,
+        mut pair_tokens: Option<Vec<String>>,
+        add_special_tokens: bool,
+    ) -> tk::Result<Vec<String>> {
+        match self {
+            PyPostProcessorTypeWrapper::Single(inner) => inner
+                .read()
+                .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor"))?
+                .process_tokens(tokens, pair_tokens, add_special_tokens),
+            PyPostProcessorTypeWrapper::Sequence(inner) => {
+                for processor in inner.iter() {
+                    let result = processor
+                        .read()
+                        .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor"))?
+                        .process_tokens(tokens, pair_tokens, add_special_tokens)?;
+                    tokens = result;
+                    pair_tokens = None;
+                }
+                Ok(tokens)
+            },
+        }
+    }
+
+    fn process_ids(
+        &self,
+        mut ids: Vec<u32>,
+        mut pair_ids: Option<Vec<u32>>,
+        add_special_tokens: bool,
+    ) -> tk::Result<Vec<u32>> {
+        match self {
+            PyPostProcessorTypeWrapper::Single(inner) => inner
+                .read()
+                .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor"))?
+                .process_ids(ids, pair_ids, add_special_tokens),
+            PyPostProcessorTypeWrapper::Sequence(inner) => {
+                for processor in inner.iter() {
+                    let result = processor
+                        .read()
+                        .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor"))?
+                        .process_ids(ids, pair_ids, add_special_tokens)?;
+                    ids = result;
+                    pair_ids = None;
+                }
+                Ok(ids)
+            },
+        }
+    }
 }

 impl<'de> Deserialize<'de> for PyPostProcessorTypeWrapper {
diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 0cd06985c..d699b60d1 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -1735,6 +1735,70 @@ impl PyTokenizer {
             .into()
     }

+    /// Post-process a list of tokens (and optionally a pair) and return the processed tokens.
+    ///
+    /// This is a simplified interface that only handles the token strings, without the full
+    /// Encoding information. Useful for step-by-step tokenization.
+    ///
+    /// Args:
+    ///     tokens (:obj:`List[str]`):
+    ///         The main sequence of tokens
+    ///
+    ///     pair (:obj:`List[str]`, `optional`):
+    ///         An optional pair sequence of tokens
+    ///
+    ///     add_special_tokens (:obj:`bool`, defaults to :obj:`True`):
+    ///         Whether to add special tokens
+    ///
+    /// Returns:
+    ///     :obj:`List[str]`: A list of tokens with special tokens added according to the post-processor
+    #[pyo3(signature = (tokens, pair=None, add_special_tokens=true))]
+    #[pyo3(text_signature = "(self, tokens, pair=None, add_special_tokens=True)")]
+    fn post_process_tokens(
+        &self,
+        tokens: Vec<String>,
+        pair: Option<Vec<String>>,
+        add_special_tokens: bool,
+    ) -> PyResult<Vec<String>> {
+        ToPyResult(
+            self.tokenizer
+                .post_process_tokens(tokens, pair, add_special_tokens),
+        )
+        .into()
+    }
+
+    /// Post-process a list of token IDs (and optionally a pair) and return the processed IDs.
+    ///
+    /// This is a simplified interface that only handles the token IDs, without the full
+    /// Encoding information. Useful for step-by-step tokenization.
+    ///
+    /// Args:
+    ///     ids (:obj:`List[int]`):
+    ///         The main sequence of token IDs
+    ///
+    ///     pair (:obj:`List[int]`, `optional`):
+    ///         An optional pair sequence of token IDs
+    ///
+    ///     add_special_tokens (:obj:`bool`, defaults to :obj:`True`):
+    ///         Whether to add special tokens
+    ///
+    /// Returns:
+    ///     :obj:`List[int]`: A list of token IDs with special tokens added according to the post-processor
+    #[pyo3(signature = (ids, pair=None, add_special_tokens=true))]
+    #[pyo3(text_signature = "(self, ids, pair=None, add_special_tokens=True)")]
+    fn post_process_ids(
+        &self,
+        ids: Vec<u32>,
+        pair: Option<Vec<u32>>,
+        add_special_tokens: bool,
+    ) -> PyResult<Vec<u32>> {
+        ToPyResult(
+            self.tokenizer
+                .post_process_ids(ids, pair, add_special_tokens),
+        )
+        .into()
+    }
+
     /// The :class:`~tokenizers.models.Model` in use by the Tokenizer
     #[getter]
     fn get_model(&self, py: Python<'_>) -> PyResult<PyObject> {
diff --git a/bindings/python/tests/bindings/test_processors.py b/bindings/python/tests/bindings/test_processors.py
index a7e0ae13e..6f5d8253e 100644
--- a/bindings/python/tests/bindings/test_processors.py
+++ b/bindings/python/tests/bindings/test_processors.py
@@ -29,6 +29,36 @@ def test_instantiate(self):
             BertProcessing,
         )

+    def test_process_tokens(self):
+        processor = BertProcessing(("[SEP]", 102), ("[CLS]", 101))
+
+        # Single sequence
+        result = processor.process_tokens(["Hello", "world"])
+        assert result == ["[CLS]", "Hello", "world", "[SEP]"]
+
+        # With pair
+        result = processor.process_tokens(["Hello"], ["world"])
+        assert result == ["[CLS]", "Hello", "[SEP]", "world", "[SEP]"]
+
+        # Without special tokens
+        result = processor.process_tokens(["Hello", "world"], add_special_tokens=False)
+        assert result == ["Hello", "world"]
+
+    def test_process_ids(self):
+        processor = BertProcessing(("[SEP]", 102), ("[CLS]", 101))
+
+        # Single sequence
+        result = processor.process_ids([10, 20])
+        assert result == [101, 10, 20, 102]
+
+        # With pair
+        result = processor.process_ids([10], [20])
+        assert result == [101, 10, 102, 20, 102]
+
+        # Without special tokens
+        result = processor.process_ids([10, 20], add_special_tokens=False)
+        assert result == [10, 20]
+
     def test_processing(self):
         tokenizer = Tokenizer(BPE())
         tokenizer.add_special_tokens(["[SEP]", "[CLS]"])
@@ -51,6 +81,36 @@ def test_instantiate(self):
             RobertaProcessing,
         )

+    def test_process_tokens(self):
+        processor = RobertaProcessing(("</s>", 1), ("<s>", 0))
+
+        # Single sequence
+        result = processor.process_tokens(["Hello", "world"])
+        assert result == ["<s>", "Hello", "world", "</s>"]
+
+        # With pair (Roberta adds extra </s> before pair)
+        result = processor.process_tokens(["Hello"], ["world"])
+        assert result == ["<s>", "Hello", "</s>", "</s>", "world", "</s>"]
+
+        # Without special tokens
+        result = processor.process_tokens(["Hello", "world"], add_special_tokens=False)
+        assert result == ["Hello", "world"]
+
+    def test_process_ids(self):
+        processor = RobertaProcessing(("</s>", 1), ("<s>", 0))
+
+        # Single sequence
+        result = processor.process_ids([10, 20])
+        assert result == [0, 10, 20, 1]
+
+        # With pair (Roberta adds extra </s> before pair)
+        result = processor.process_ids([10], [20])
+        assert result == [0, 10, 1, 1, 20, 1]
+
+        # Without special tokens
+        result = processor.process_ids([10, 20], add_special_tokens=False)
+        assert result == [10, 20]
+
     def test_processing(self):
         tokenizer = Tokenizer(BPE())
         tokenizer.add_special_tokens(["<s>", "</s>"])
@@ -193,6 +253,36 @@ def test_roberta_parity(self):
         template = tokenizer.encode("my name is john", "pair")
         assert original.ids == template.ids

+    def test_process_tokens(self):
+        processor = self.get_bert()
+
+        # Single sequence
+        result = processor.process_tokens(["Hello", "world"])
+        assert result == ["[CLS]", "Hello", "world", "[SEP]"]
+
+        # With pair
+        result = processor.process_tokens(["Hello"], ["world"])
+        assert result == ["[CLS]", "Hello", "[SEP]", "world", "[SEP]"]
+
+        # Without special tokens
+        result = processor.process_tokens(["Hello", "world"], add_special_tokens=False)
+        assert result == ["Hello", "world"]
+
+    def test_process_ids(self):
+        processor = self.get_bert()
+
+        # Single sequence
+        result = processor.process_ids([10, 20])
+        assert result == [1, 10, 20, 0]
+
+        # With pair
+        result = processor.process_ids([10], [20])
+        assert result == [1, 10, 0, 20, 0]
+
+        # Without special tokens
+        result = processor.process_ids([10, 20], add_special_tokens=False)
+        assert result == [10, 20]
+

 class TestSequenceProcessing:
     def test_sequence_processing(self):
diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py
index 96b75b24d..fcd353004 100644
--- a/bindings/python/tests/bindings/test_tokenizer.py
+++ b/bindings/python/tests/bindings/test_tokenizer.py
@@ -572,6 +572,46 @@ def test_post_process(self):
         output = tokenizer.post_process(encoding, pair_encoding)
         assert output.tokens == ["my", "pair", "[PAD]", "[PAD]"]

+    def test_post_process_tokens(self):
+        from tokenizers.processors import BertProcessing
+
+        tokenizer = Tokenizer(BPE())
+        tokenizer.add_special_tokens(["[SEP]", "[CLS]"])
+        tokenizer.add_tokens(["my", "name", "is", "john", "pair"])
+        tokenizer.post_processor = BertProcessing(("[SEP]", 0), ("[CLS]", 1))
+
+        # Single sequence
+        result = tokenizer.post_process_tokens(["my", "name"])
+        assert result == ["[CLS]", "my", "name", "[SEP]"]
+
+        # With pair
+        result = tokenizer.post_process_tokens(["my", "name"], ["pair"])
+        assert result == ["[CLS]", "my", "name", "[SEP]", "pair", "[SEP]"]
+
+        # Without special tokens
+        result = tokenizer.post_process_tokens(["my", "name"], add_special_tokens=False)
+        assert result == ["my", "name"]
+
+    def test_post_process_ids(self):
+        from tokenizers.processors import BertProcessing
+
+        tokenizer = Tokenizer(BPE())
+        tokenizer.add_special_tokens(["[SEP]", "[CLS]"])
+        tokenizer.add_tokens(["my", "name", "is", "john", "pair"])
+        tokenizer.post_processor = BertProcessing(("[SEP]", 0), ("[CLS]", 1))
+
+        # Single sequence
+        result = tokenizer.post_process_ids([2, 3])
+        assert result == [1, 2, 3, 0]
+
+        # With pair
+        result = tokenizer.post_process_ids([2, 3], [6])
+        assert result == [1, 2, 3, 0, 6, 0]
+
+        # Without special tokens
+        result = tokenizer.post_process_ids([2, 3], add_special_tokens=False)
+        assert result == [2, 3]
+
     def test_multiprocessing_with_parallelism(self):
         tokenizer = Tokenizer(BPE())
         multiprocessing_with_parallelism(tokenizer, False)
diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs
index 8bc0f30af..c928f4d10 100644
--- a/tokenizers/src/pre_tokenizers/byte_level.rs
+++ b/tokenizers/src/pre_tokenizers/byte_level.rs
@@ -197,6 +197,34 @@ impl PostProcessor for ByteLevel {
         Ok(encodings) //<dyn PostProcessor>::default_process(encodings, add_special_tokens)
     }
+
+    fn process_tokens(
+        &self,
+        tokens: Vec<String>,
+        pair_tokens: Option<Vec<String>>,
+        _add_special_tokens: bool,
+    ) -> Result<Vec<String>> {
+        // ByteLevel doesn't add any special tokens, just concatenates
+        let mut result = tokens;
+        if let Some(pair) = pair_tokens {
+            result.extend(pair);
+        }
+        Ok(result)
+    }
+
+    fn process_ids(
+        &self,
+        ids: Vec<u32>,
+        pair_ids: Option<Vec<u32>>,
+        _add_special_tokens: bool,
+    ) -> Result<Vec<u32>> {
+        // ByteLevel doesn't add any special tokens, just concatenates
+        let mut result = ids;
+        if let Some(pair) = pair_ids {
+            result.extend(pair);
+        }
+        Ok(result)
+    }
 }

 pub fn process_offsets(encoding: &mut Encoding, add_prefix_space: bool) {
diff --git a/tokenizers/src/processors/bert.rs b/tokenizers/src/processors/bert.rs
index a1cab8abd..d92578ed6 100644
--- a/tokenizers/src/processors/bert.rs
+++ b/tokenizers/src/processors/bert.rs
@@ -190,6 +190,58 @@ impl PostProcessor for BertProcessing {

         Ok(encodings)
     }
+
+    fn process_tokens(
+        &self,
+        tokens: Vec<String>,
+        pair_tokens: Option<Vec<String>>,
+        add_special_tokens: bool,
+    ) -> Result<Vec<String>> {
+        if !add_special_tokens {
+            let mut result = tokens;
+            if let Some(pair) = pair_tokens {
+                result.extend(pair);
+            }
+            return Ok(result);
+        }
+
+        let mut result = vec![self.cls.0.clone()];
+        result.extend(tokens);
+        result.push(self.sep.0.clone());
+
+        if let Some(pair) = pair_tokens {
+            result.extend(pair);
+            result.push(self.sep.0.clone());
+        }
+
+        Ok(result)
+    }
+
+    fn process_ids(
+        &self,
+        ids: Vec<u32>,
+        pair_ids: Option<Vec<u32>>,
+        add_special_tokens: bool,
+    ) -> Result<Vec<u32>> {
+        if !add_special_tokens {
+            let mut result = ids;
+            if let Some(pair) = pair_ids {
+                result.extend(pair);
+            }
+            return Ok(result);
+        }
+
+        let mut result = vec![self.cls.1];
+        result.extend(ids);
+        result.push(self.sep.1);
+
+        if let Some(pair) = pair_ids {
+            result.extend(pair);
+            result.push(self.sep.1);
+        }
+
+        Ok(result)
+    }
 }

 #[cfg(test)]
diff --git a/tokenizers/src/processors/mod.rs b/tokenizers/src/processors/mod.rs
index 869cc6891..0db9b26f2 100644
--- a/tokenizers/src/processors/mod.rs
+++ b/tokenizers/src/processors/mod.rs
@@ -50,6 +50,38 @@ impl PostProcessor for PostProcessorWrapper {
             Self::Sequence(bl) => bl.process_encodings(encodings, add_special_tokens),
         }
     }
+
+    fn process_tokens(
+        &self,
+        tokens: Vec<String>,
+        pair_tokens: Option<Vec<String>>,
+        add_special_tokens: bool,
+    ) -> Result<Vec<String>> {
+        match self {
+            Self::Bert(bert) => bert.process_tokens(tokens, pair_tokens, add_special_tokens),
+            Self::ByteLevel(bl) => bl.process_tokens(tokens, pair_tokens, add_special_tokens),
+            Self::Roberta(roberta) => roberta.process_tokens(tokens, pair_tokens, add_special_tokens),
+            Self::Template(template) => {
+                template.process_tokens(tokens, pair_tokens, add_special_tokens)
+            }
+            Self::Sequence(seq) => seq.process_tokens(tokens, pair_tokens, add_special_tokens),
+        }
+    }
+
+    fn process_ids(
+        &self,
+        ids: Vec<u32>,
+        pair_ids: Option<Vec<u32>>,
+        add_special_tokens: bool,
+    ) -> Result<Vec<u32>> {
+        match self {
+            Self::Bert(bert) => bert.process_ids(ids, pair_ids, add_special_tokens),
+            Self::ByteLevel(bl) => bl.process_ids(ids, pair_ids, add_special_tokens),
+            Self::Roberta(roberta) => roberta.process_ids(ids, pair_ids, add_special_tokens),
+            Self::Template(template) => template.process_ids(ids, pair_ids, add_special_tokens),
+            Self::Sequence(seq) => seq.process_ids(ids, pair_ids, add_special_tokens),
+        }
+    }
 }

 impl_enum_from!(BertProcessing, PostProcessorWrapper, Bert);
diff --git a/tokenizers/src/processors/roberta.rs b/tokenizers/src/processors/roberta.rs
index f2a47a9d3..870d78071 100644
--- a/tokenizers/src/processors/roberta.rs
+++ b/tokenizers/src/processors/roberta.rs
@@ -231,6 +231,62 @@ impl PostProcessor for RobertaProcessing {

         Ok(encodings)
     }
+
+    fn process_tokens(
+        &self,
+        tokens: Vec<String>,
+        pair_tokens: Option<Vec<String>>,
+        add_special_tokens: bool,
+    ) -> Result<Vec<String>> {
+        if !add_special_tokens {
+            let mut result = tokens;
+            if let Some(pair) = pair_tokens {
+                result.extend(pair);
+            }
+            return Ok(result);
+        }
+
+        // Roberta: <s> ... </s> </s> ... </s>
+        let mut result = vec![self.cls.0.clone()];
+        result.extend(tokens);
+        result.push(self.sep.0.clone());
+
+        if let Some(pair) = pair_tokens {
+            result.push(self.sep.0.clone()); // Extra </s> before pair
+            result.extend(pair);
+            result.push(self.sep.0.clone());
+        }
+
+        Ok(result)
+    }
+
+    fn process_ids(
+        &self,
+        ids: Vec<u32>,
+        pair_ids: Option<Vec<u32>>,
+        add_special_tokens: bool,
+    ) -> Result<Vec<u32>> {
+        if !add_special_tokens {
+            let mut result = ids;
+            if let Some(pair) = pair_ids {
+                result.extend(pair);
+            }
+            return Ok(result);
+        }
+
+        // Roberta: <s> ... </s> </s> ... </s>
+        let mut result = vec![self.cls.1];
+        result.extend(ids);
+        result.push(self.sep.1);
+
+        if let Some(pair) = pair_ids {
+            result.push(self.sep.1); // Extra </s> before pair
+            result.extend(pair);
+            result.push(self.sep.1);
+        }
+
+        Ok(result)
+    }
 }

 #[cfg(test)]
diff --git a/tokenizers/src/processors/sequence.rs b/tokenizers/src/processors/sequence.rs
index f44cf54ac..0a456bc8d 100644
--- a/tokenizers/src/processors/sequence.rs
+++ b/tokenizers/src/processors/sequence.rs
@@ -66,6 +66,36 @@ impl PostProcessor for Sequence {
         }
         Ok(encodings)
     }
+
+    fn process_tokens(
+        &self,
+        mut tokens: Vec<String>,
+        mut pair_tokens: Option<Vec<String>>,
+        add_special_tokens: bool,
+    ) -> Result<Vec<String>> {
+        for processor in &self.processors {
+            // After the first processor, the pair is merged into a single sequence
+            let result = processor.process_tokens(tokens, pair_tokens, add_special_tokens)?;
+            tokens = result;
+            pair_tokens = None;
+        }
+        Ok(tokens)
+    }
+
+    fn process_ids(
+        &self,
+        mut ids: Vec<u32>,
+        mut pair_ids: Option<Vec<u32>>,
+        add_special_tokens: bool,
+    ) -> Result<Vec<u32>> {
+        for processor in &self.processors {
+            // After the first processor, the pair is merged into a single sequence
+            let result = processor.process_ids(ids, pair_ids, add_special_tokens)?;
+            ids = result;
+            pair_ids = None;
+        }
+        Ok(ids)
+    }
 }

 #[cfg(test)]
diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs
index 50fac99df..b8984e28d 100644
--- a/tokenizers/src/processors/template.rs
+++ b/tokenizers/src/processors/template.rs
@@ -683,6 +683,78 @@ impl PostProcessor for TemplateProcessing {
         let encodings = self.apply_template(template, encodings, add_special_tokens)?;
         Ok(encodings)
     }
+
+    fn process_tokens(
+        &self,
+        tokens: Vec<String>,
+        pair_tokens: Option<Vec<String>>,
+        add_special_tokens: bool,
+    ) -> Result<Vec<String>> {
+        let template = if pair_tokens.is_some() {
+            &self.pair.0
+        } else {
+            &self.single.0
+        };
+
+        let mut result = Vec::new();
+
+        for piece in template {
+            match piece {
+                Piece::Sequence { id, .. } => {
+                    let seq_tokens = match id {
+                        Sequence::A => &tokens,
+                        Sequence::B => pair_tokens.as_ref().unwrap_or(&tokens),
+                    };
+                    result.extend(seq_tokens.iter().cloned());
+                }
+                Piece::SpecialToken { id, .. } => {
+                    if add_special_tokens {
+                        if let Some(tok) = self.special_tokens.0.get(id) {
+                            result.extend(tok.tokens.iter().cloned());
+                        }
+                    }
+                }
+            }
+        }
+
+        Ok(result)
+    }
+
+    fn process_ids(
+        &self,
+        ids: Vec<u32>,
+        pair_ids: Option<Vec<u32>>,
+        add_special_tokens: bool,
+    ) -> Result<Vec<u32>> {
+        let template = if pair_ids.is_some() {
+            &self.pair.0
+        } else {
+            &self.single.0
+        };
+
+        let mut result = Vec::new();
+
+        for piece in template {
+            match piece {
+                Piece::Sequence { id, .. } => {
+                    let seq_ids = match id {
+                        Sequence::A => &ids,
+                        Sequence::B => pair_ids.as_ref().unwrap_or(&ids),
+                    };
+                    result.extend(seq_ids.iter().copied());
+                }
+                Piece::SpecialToken { id, .. } => {
+                    if add_special_tokens {
+                        if let Some(tok) = self.special_tokens.0.get(id) {
+                            result.extend(tok.ids.iter().copied());
+                        }
+                    }
+                }
+            }
+        }
+
+        Ok(result)
+    }
 }

 #[cfg(test)]
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index cedabeebc..250a61363 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -123,6 +123,24 @@ pub trait PostProcessor {
         encodings: Vec<Encoding>,
         add_special_tokens: bool,
     ) -> Result<Vec<Encoding>>;
+
+    /// Process a list of tokens (and optionally a pair) and return the processed tokens.
+    /// This is a simplified interface that only handles the token strings.
+    fn process_tokens(
+        &self,
+        tokens: Vec<String>,
+        pair_tokens: Option<Vec<String>>,
+        add_special_tokens: bool,
+    ) -> Result<Vec<String>>;
+
+    /// Process a list of token IDs (and optionally a pair) and return the processed IDs.
+    /// This is a simplified interface that only handles the token IDs.
+    fn process_ids(
+        &self,
+        ids: Vec<u32>,
+        pair_ids: Option<Vec<u32>>,
+        add_special_tokens: bool,
+    ) -> Result<Vec<u32>>;
 }
 impl dyn PostProcessor {
     pub fn default_process(
@@ -1257,6 +1275,64 @@ where
         Ok(final_encoding)
     }

+    /// Post-process a list of tokens (and optionally a pair) and return the processed tokens.
+    /// This is a simplified interface that only handles the token strings, without the full
+    /// Encoding information.
+    ///
+    /// # Arguments
+    /// * `tokens` - The main sequence of tokens
+    /// * `pair_tokens` - An optional pair sequence of tokens
+    /// * `add_special_tokens` - Whether to add special tokens
+    ///
+    /// # Returns
+    /// A list of tokens with special tokens added according to the post-processor
+    pub fn post_process_tokens(
+        &self,
+        tokens: Vec<String>,
+        pair_tokens: Option<Vec<String>>,
+        add_special_tokens: bool,
+    ) -> Result<Vec<String>> {
+        if let Some(processor) = &self.post_processor {
+            processor.process_tokens(tokens, pair_tokens, add_special_tokens)
+        } else {
+            // Default: just concatenate the sequences
+            let mut result = tokens;
+            if let Some(pair) = pair_tokens {
+                result.extend(pair);
+            }
+            Ok(result)
+        }
+    }
+
+    /// Post-process a list of token IDs (and optionally a pair) and return the processed IDs.
+    /// This is a simplified interface that only handles the token IDs, without the full
+    /// Encoding information.
+    ///
+    /// # Arguments
+    /// * `ids` - The main sequence of token IDs
+    /// * `pair_ids` - An optional pair sequence of token IDs
+    /// * `add_special_tokens` - Whether to add special tokens
+    ///
+    /// # Returns
+    /// A list of token IDs with special tokens added according to the post-processor
+    pub fn post_process_ids(
+        &self,
+        ids: Vec<u32>,
+        pair_ids: Option<Vec<u32>>,
+        add_special_tokens: bool,
+    ) -> Result<Vec<u32>> {
+        if let Some(processor) = &self.post_processor {
+            processor.process_ids(ids, pair_ids, add_special_tokens)
+        } else {
+            // Default: just concatenate the sequences
+            let mut result = ids;
+            if let Some(pair) = pair_ids {
+                result.extend(pair);
+            }
+            Ok(result)
+        }
+    }
+
     fn get_n_added_tokens(&self, is_pair: bool) -> usize {
         if let Some(processor) = &self.post_processor {
             processor.added_tokens(is_pair)
diff --git a/tokenizers/tests/serialization.rs b/tokenizers/tests/serialization.rs
index dc0c95a57..471eb8ac3 100644
--- a/tokenizers/tests/serialization.rs
+++ b/tokenizers/tests/serialization.rs
@@ -17,7 +17,7 @@ use tokenizers::pre_tokenizers::whitespace::Whitespace;
 use tokenizers::pre_tokenizers::PreTokenizerWrapper;
 use tokenizers::processors::bert::BertProcessing;
 use tokenizers::processors::PostProcessorWrapper;
-use tokenizers::{SplitDelimiterBehavior, Tokenizer, TokenizerImpl};
+use tokenizers::{PostProcessor, SplitDelimiterBehavior, Tokenizer, TokenizerImpl};

 #[test]
 fn bpe_serde() {
@@ -248,3 +248,49 @@ fn bpe_with_dropout_serde() {
 fn test_deserialize_long_file() {
     let _tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json").unwrap();
 }
+
+#[test]
+fn test_bert_process_tokens() {
+    let processor = BertProcessing::new(("[SEP]".into(), 102), ("[CLS]".into(), 101));
+    let tokens = vec!["Hello".into(), "world".into()];
+    let result = processor
+        .process_tokens(tokens, None, true)
+        .unwrap();
+    assert_eq!(
+        result,
+        vec!["[CLS]", "Hello", "world", "[SEP]"]
+    );
+
+    // With pair
+    let tokens = vec!["Hello".into()];
+    let pair = Some(vec!["world".into()]);
+    let result = processor.process_tokens(tokens, pair, true).unwrap();
+    assert_eq!(
+        result,
+        vec!["[CLS]", "Hello", "[SEP]", "world", "[SEP]"]
+    );
+
+    // Without special tokens
+    let tokens = vec!["Hello".into(), "world".into()];
+    let result = processor.process_tokens(tokens, None, false).unwrap();
+    assert_eq!(result, vec!["Hello", "world"]);
+}
+
+#[test]
+fn test_bert_process_ids() {
+    let processor = BertProcessing::new(("[SEP]".into(), 102), ("[CLS]".into(), 101));
+    let ids = vec![10, 20];
+    let result = processor.process_ids(ids, None, true).unwrap();
+    assert_eq!(result, vec![101, 10, 20, 102]);
+
+    // With pair
+    let ids = vec![10];
+    let pair = Some(vec![20]);
+    let result = processor.process_ids(ids, pair, true).unwrap();
+    assert_eq!(result, vec![101, 10, 102, 20, 102]);
+
+    // Without special tokens
+    let ids = vec![10, 20];
+    let result = processor.process_ids(ids, None, false).unwrap();
+    assert_eq!(result, vec![10, 20]);
+}
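
A minimal usage sketch of the Python surface this patch exposes. It assumes a `tokenizers` build that includes the patch; the tokens and ids below are toy values (added to an empty BPE in order, as in the diff's own tests), not a real vocabulary:

    from tokenizers import Tokenizer
    from tokenizers.models import BPE
    from tokenizers.processors import BertProcessing

    tokenizer = Tokenizer(BPE())
    tokenizer.add_special_tokens(["[SEP]", "[CLS]"])  # ids 0 and 1
    tokenizer.add_tokens(["my", "name"])              # ids 2 and 3
    tokenizer.post_processor = BertProcessing(("[SEP]", 0), ("[CLS]", 1))

    # Tokenizer-level: post-process raw token strings or ids without building an Encoding
    assert tokenizer.post_process_tokens(["my", "name"]) == ["[CLS]", "my", "name", "[SEP]"]
    assert tokenizer.post_process_ids([2, 3]) == [1, 2, 3, 0]

    # Processor-level: the same simplified interface lives directly on the post-processor
    processor = tokenizer.post_processor
    assert processor.process_ids([2, 3], [3]) == [1, 2, 3, 0, 3, 0]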
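One design choice worth illustrating from the `Sequence` implementations above: the pair is consumed by the first processor in the chain, and every later processor sees a single merged sequence. A hypothetical sketch of that effect, beyond what the diff's tests cover:

    from tokenizers.processors import BertProcessing, ByteLevel, Sequence

    seq = Sequence([ByteLevel(trim_offsets=False), BertProcessing(("[SEP]", 102), ("[CLS]", 101))])

    # ByteLevel adds no special tokens and concatenates the pair into the main
    # sequence; BertProcessing then wraps the merged sequence as a single one,
    # so only a single trailing [SEP] is produced.
    assert seq.process_tokens(["Hello"], ["world"]) == ["[CLS]", "Hello", "world", "[SEP]"]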