3 changes: 2 additions & 1 deletion bindings/python/Cargo.toml
@@ -32,5 +32,6 @@ tempfile = "3.10"
pyo3 = { version = "0.28", features = ["auto-initialize", "experimental-inspect"] }

[features]
default = ["ext-module"]
default = ["ext-module", "parity-aware-bpe"]
ext-module = ["pyo3/extension-module"]
parity-aware-bpe = ["tokenizers/parity-aware-bpe"]
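Since `parity-aware-bpe` now sits in the default feature set, a quick sanity check (a sketch, not part of this diff) that a build with default features exposes the new trainer from Python:

```python
from tokenizers import trainers

# ParityBpeTrainer is only compiled in when the bindings are built with the
# "parity-aware-bpe" feature, which this change adds to the defaults.
assert hasattr(trainers, "ParityBpeTrainer")
```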
86 changes: 86 additions & 0 deletions bindings/python/py_src/tokenizers/trainers.pyi
@@ -108,6 +108,92 @@ class BpeTrainer(Trainer):
@vocab_size.setter
def vocab_size(self, /, vocab_size: int) -> None: ...

@final
class ParityBpeTrainer:
def __getstate__(self, /) -> Any: ...
def __new__(
cls,
/,
num_merges: int = 32000,
variant: str = "base",
min_frequency: int = 0,
ratio: Sequence[float] | None = None,
global_merges: int = 0,
window_size: int = 100,
alpha: float = 2.0,
total_symbols: bool = False,
special_tokens: list | None = None,
show_progress: bool = True,
limit_alphabet: int | None = None,
initial_alphabet: Sequence[str] | None = None,
continuing_subword_prefix: str | None = None,
end_of_word_suffix: str | None = None,
max_token_length: int | None = None,
) -> ParityBpeTrainer:
"""Create and return a new object. See help(type) for accurate signature."""
...
def __repr__(self, /) -> str:
"""Return repr(self)."""
...
def __setstate__(self, /, state: Any) -> None: ...
def __str__(self, /) -> str:
"""Return str(self)."""
...
@property
def alpha(self, /) -> float: ...
@alpha.setter
def alpha(self, /, v: float) -> None: ...
@property
def continuing_subword_prefix(self, /) -> str | None: ...
@continuing_subword_prefix.setter
def continuing_subword_prefix(self, /, v: str | None) -> None: ...
@property
def end_of_word_suffix(self, /) -> str | None: ...
@end_of_word_suffix.setter
def end_of_word_suffix(self, /, v: str | None) -> None: ...
@property
def global_merges(self, /) -> int: ...
@global_merges.setter
def global_merges(self, /, v: int) -> None: ...
@property
def initial_alphabet(self, /) -> list[str]: ...
@initial_alphabet.setter
def initial_alphabet(self, /, alphabet: Sequence[str]) -> None: ...
@property
def limit_alphabet(self, /) -> int | None: ...
@limit_alphabet.setter
def limit_alphabet(self, /, v: int | None) -> None: ...
@property
def max_token_length(self, /) -> int | None: ...
@max_token_length.setter
def max_token_length(self, /, v: int | None) -> None: ...
@property
def min_frequency(self, /) -> int: ...
@min_frequency.setter
def min_frequency(self, /, v: int) -> None: ...
@property
def num_merges(self, /) -> int: ...
@num_merges.setter
def num_merges(self, /, v: int) -> None: ...
@property
def show_progress(self, /) -> bool: ...
@show_progress.setter
def show_progress(self, /, v: bool) -> None: ...
@property
def special_tokens(self, /) -> list[AddedToken]: ...
@special_tokens.setter
def special_tokens(self, /, special_tokens: list) -> None: ...
@property
def total_symbols(self, /) -> bool: ...
@total_symbols.setter
def total_symbols(self, /, v: bool) -> None: ...
@property
def variant(self, /) -> str: ...
@property
def window_size(self, /) -> int: ...
@window_size.setter
def window_size(self, /, v: int) -> None: ...

class Trainer:
"""
Base class for all trainers
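For orientation, a hedged sketch of constructing the trainer against the stub above; the parameter values are arbitrary and only meant to illustrate the two variants:

```python
from tokenizers.trainers import ParityBpeTrainer

# Base variant: parity-aware merge selection from the first merge on.
base_trainer = ParityBpeTrainer(
    num_merges=32000,
    variant="base",
    min_frequency=2,
    special_tokens=["[UNK]", "[PAD]"],
)

# Window variant: run 500 standard BPE merges first (global_merges), then
# switch to the moving-window balancing controlled by window_size and alpha.
window_trainer = ParityBpeTrainer(
    variant="window",
    global_merges=500,
    window_size=100,
    alpha=2.0,
)

print(base_trainer.num_merges, window_trainer.variant)
```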
1 change: 1 addition & 0 deletions bindings/python/py_src/tokenizers/trainers/__init__.py
@@ -3,6 +3,7 @@
from .. import trainers

BpeTrainer = trainers.BpeTrainer
ParityBpeTrainer = trainers.ParityBpeTrainer
Trainer = trainers.Trainer
UnigramTrainer = trainers.UnigramTrainer
WordLevelTrainer = trainers.WordLevelTrainer
250 changes: 250 additions & 0 deletions bindings/python/py_src/tokenizers/trainers/__init__.pyi
@@ -328,6 +328,256 @@ class WordLevelTrainer(Trainer):
""" """
pass

class ParityBpeTrainer:
"""
Trainer for parity-aware BPE that ensures cross-lingual fairness in tokenization.

Unlike standard BPE, this trainer takes one Python iterator per language and
balances merge operations across languages using a development set or target
compression ratios. The single training entry point is :meth:`train_from_iterator`,
the multi-corpus analogue of :meth:`tokenizers.Tokenizer.train_from_iterator`.

Args:
num_merges (:obj:`int`, `optional`):
Number of BPE merge operations to perform. Defaults to ``32000``.

variant (:obj:`str`, `optional`):
Algorithm variant: ``"base"`` (default) or ``"window"`` (moving-window balancing).

min_frequency (:obj:`int`, `optional`):
The minimum frequency a pair must have in order to be merged. Defaults to ``0``.

ratio (:obj:`List[float]`, `optional`):
Target compression ratios per language, an alternative to providing a
development set (see :meth:`train_from_iterator`).

global_merges (:obj:`int`, `optional`):
Number of initial standard BPE merges before switching to parity mode.
Defaults to ``0``.

window_size (:obj:`int`, `optional`):
Window size for the ``"window"`` variant. Defaults to ``100``.

alpha (:obj:`float`, `optional`):
Alpha parameter for the ``"window"`` variant. Defaults to ``2.0``.

total_symbols (:obj:`bool`, `optional`):
If ``True``, subtract the number of unique characters from ``num_merges``.
Defaults to ``False``.

special_tokens (:obj:`List[Union[str, AddedToken]]`, `optional`):
A list of special tokens the model should know of.

show_progress (:obj:`bool`, `optional`):
Whether to show progress bars while training. Defaults to ``True``.

limit_alphabet (:obj:`int`, `optional`):
The maximum number of different characters to keep in the alphabet.

initial_alphabet (:obj:`List[str]`, `optional`):
A list of characters to include in the initial alphabet, even
if not seen in the training dataset.

continuing_subword_prefix (:obj:`str`, `optional`):
A prefix to be used for every subword that is not a beginning-of-word.

end_of_word_suffix (:obj:`str`, `optional`):
A suffix to be used for every subword that is an end-of-word.

max_token_length (:obj:`int`, `optional`):
Prevents creating tokens longer than the specified size.
"""
def __init__(
self,
num_merges=32000,
variant="base",
min_frequency=0,
ratio=None,
global_merges=0,
window_size=100,
alpha=2.0,
total_symbols=False,
special_tokens=None,
show_progress=True,
limit_alphabet=None,
initial_alphabet=None,
continuing_subword_prefix=None,
end_of_word_suffix=None,
max_token_length=None,
):
pass

def __repr__(self) -> str: ...
def __str__(self) -> str: ...
def __getstate__(self): ...
def __setstate__(self, state): ...
def train_from_iterator(
self,
tokenizer,
train_iterators,
dev_iterators=None,
ratio=None,
):
"""
Train a user-configured tokenizer with parity-aware BPE from per-language
Python iterators.

This is the multi-corpus analogue of
:meth:`~tokenizers.Tokenizer.train_from_iterator`: file I/O happens in
Python, so users can pull data from plain text, parquet (via ``pyarrow``),
``datasets``, etc.

Args:
tokenizer (:class:`~tokenizers.Tokenizer`):
A tokenizer instance to train. Its pre-tokenizer (and optionally
normalizer) should already be configured.

train_iterators (:obj:`List[Iterator]`):
One Python iterator per language, each yielding ``str`` or
``List[str]``.

dev_iterators (:obj:`List[Iterator]`, `optional`):
One Python iterator per language, used to drive parity-aware
language selection. Must have the same length as
``train_iterators``.

ratio (:obj:`List[float]`, `optional`):
Target compression ratios per language (alternative to
``dev_iterators``).
"""
pass

@property
def special_tokens(self):
""" """
pass

@special_tokens.setter
def special_tokens(self, value):
""" """
pass

@property
def show_progress(self):
""" """
pass

@show_progress.setter
def show_progress(self, value):
""" """
pass

@property
def limit_alphabet(self):
""" """
pass

@limit_alphabet.setter
def limit_alphabet(self, value):
""" """
pass

@property
def initial_alphabet(self):
""" """
pass

@initial_alphabet.setter
def initial_alphabet(self, value):
""" """
pass

@property
def continuing_subword_prefix(self):
""" """
pass

@continuing_subword_prefix.setter
def continuing_subword_prefix(self, value):
""" """
pass

@property
def end_of_word_suffix(self):
""" """
pass

@end_of_word_suffix.setter
def end_of_word_suffix(self, value):
""" """
pass

@property
def max_token_length(self):
""" """
pass

@max_token_length.setter
def max_token_length(self, value):
""" """
pass

@property
def min_frequency(self):
""" """
pass

@min_frequency.setter
def min_frequency(self, value):
""" """
pass

@property
def num_merges(self):
""" """
pass

@num_merges.setter
def num_merges(self, value):
""" """
pass

@property
def variant(self):
""" """
pass

@property
def global_merges(self):
""" """
pass

@global_merges.setter
def global_merges(self, value):
""" """
pass

@property
def window_size(self):
""" """
pass

@window_size.setter
def window_size(self, value):
""" """
pass

@property
def alpha(self):
""" """
pass

@alpha.setter
def alpha(self, value):
""" """
pass

@property
def total_symbols(self):
""" """
pass

@total_symbols.setter
def total_symbols(self, value):
""" """
pass

class WordPieceTrainer(Trainer):
"""
Trainer capable of training a WordPiece model
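To tie the pieces together, a minimal end-to-end sketch of the `train_from_iterator` flow documented above; the two in-memory corpora and the BPE/Whitespace setup are placeholder assumptions, not part of this PR:

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import ParityBpeTrainer

# Placeholder per-language corpora; any iterables yielding str (or List[str]) work.
english_train = ["the quick brown fox jumps over the lazy dog", "hello world"]
swahili_train = ["mbweha mwepesi anaruka juu ya mbwa mvivu", "habari dunia"]
english_dev = ["a small held-out english sample"]
swahili_dev = ["sampuli ndogo ya kiswahili"]

# The tokenizer's pre-tokenizer (and optionally normalizer) must be configured
# before training, as the docstring above requires.
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()

trainer = ParityBpeTrainer(
    num_merges=1000,
    variant="base",
    special_tokens=["[UNK]", "[PAD]"],
)

# One training iterator per language; the matching dev iterators drive the
# parity-aware choice of which language's merge to apply next. Alternatively,
# pass ratio=[...] with one target compression ratio per language instead.
trainer.train_from_iterator(
    tokenizer,
    train_iterators=[iter(english_train), iter(swahili_train)],
    dev_iterators=[iter(english_dev), iter(swahili_dev)],
)

print(tokenizer.get_vocab_size())
```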