Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import asyncio
import json
import logging
import math
import re
import time
Expand All @@ -25,6 +24,9 @@
MAX_HISTORY_TOKENS = 128
MAX_HISTORY_TURNS = 6

# files we need from the HF repo besides the ONNX model
_TOKENIZER_FILES = ("tokenizer.json", "tokenizer_config.json")


def _download_from_hf_hub(repo_id: str, filename: str, **kwargs: Any) -> str:
from huggingface_hub import hf_hub_download
Expand Down Expand Up @@ -83,34 +85,22 @@ def _format_chat_ctx(self, chat_ctx: list[dict[str, Any]]) -> str:
new_chat_ctx.append(msg)
last_msg = msg

convo_text = self._tokenizer.apply_chat_template(
new_chat_ctx, add_generation_prompt=False, add_special_tokens=False, tokenize=False
convo_text: str = self._chat_template.render(
messages=new_chat_ctx, add_generation_prompt=False
)

# remove the EOU token from current utterance
ix = convo_text.rfind("<|im_end|>")
text = convo_text[:ix]
return text # type: ignore
return convo_text[:ix]

def initialize(self) -> None:
logger = logging.getLogger("transformers")
import numpy as np
import onnxruntime as ort # type: ignore
from huggingface_hub import errors
from jinja2 import Environment
from tokenizers import Tokenizer # type: ignore[import-untyped]

class _SuppressSpecific(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
msg = record.getMessage()
return not msg.startswith(
"None of PyTorch, TensorFlow >= 2.0, or Flax have been found."
)

filt = _SuppressSpecific()
# filter this log since it conflicts with the console CLI (since it directly prints to stdout)
logger.addFilter(filt)
try:
import onnxruntime as ort # type: ignore
from huggingface_hub import errors
from transformers import AutoTokenizer
finally:
logger.removeFilter(filt)
self._np = np

revision = self.__class__.model_revision()
try:
Expand All @@ -121,6 +111,13 @@ def filter(self, record: logging.LogRecord) -> bool:
revision=revision,
local_files_only=True,
)
tokenizer_json_path = _download_from_hf_hub(
HG_MODEL, "tokenizer.json", revision=revision, local_files_only=True
)
tokenizer_config_path = _download_from_hf_hub(
HG_MODEL, "tokenizer_config.json", revision=revision, local_files_only=True
)

sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = max(
1, min(math.ceil(hw.get_cpu_monitor().cpu_count()) // 2, 4)
Expand All @@ -130,12 +127,20 @@ def filter(self, record: logging.LogRecord) -> bool:
self._session = ort.InferenceSession(
local_path_onnx, providers=["CPUExecutionProvider"], sess_options=sess_options
)
self._tokenizer = AutoTokenizer.from_pretrained( # type: ignore[no-untyped-call]
HG_MODEL,
revision=revision,
local_files_only=True,
truncation_side="left",
)

self._tokenizer = Tokenizer.from_file(tokenizer_json_path)
# match the previous transformers behavior: left-truncate to MAX_HISTORY_TOKENS
self._tokenizer.enable_truncation(max_length=MAX_HISTORY_TOKENS, direction="left")

with open(tokenizer_config_path) as f:
tokenizer_config = json.load(f)
chat_template = tokenizer_config.get("chat_template")
if not chat_template:
raise RuntimeError(
f"tokenizer_config.json for {HG_MODEL}@{revision} has no chat_template"
)
# the EOU model templates are simple message loops; no custom helpers needed
self._chat_template = Environment(autoescape=False).from_string(chat_template)

except (errors.LocalEntryNotFoundError, OSError):
logger.error(
Expand All @@ -157,15 +162,9 @@ def run(self, data: bytes) -> bytes | None:

start_time = time.perf_counter()
text = self._format_chat_ctx(chat_ctx)
inputs = self._tokenizer(
text,
add_special_tokens=False,
return_tensors="np",
max_length=MAX_HISTORY_TOKENS,
truncation=True,
)
# run inference
outputs = self._session.run(None, {"input_ids": inputs["input_ids"].astype("int64")})
encoding = self._tokenizer.encode(text, add_special_tokens=False)
input_ids = self._np.asarray(encoding.ids, dtype=self._np.int64).reshape(1, -1)
outputs = self._session.run(None, {"input_ids": input_ids})
eou_probability = outputs[0].flatten()[-1]
end_time = time.perf_counter()

Expand All @@ -178,14 +177,11 @@ def run(self, data: bytes) -> bytes | None:

@classmethod
def _download_files(cls) -> None:
from transformers import AutoTokenizer

# ensure the tokenizer is downloaded
AutoTokenizer.from_pretrained(HG_MODEL, revision=cls.model_revision()) # type: ignore[no-untyped-call]
_download_from_hf_hub(
HG_MODEL, ONNX_FILENAME, subfolder="onnx", revision=cls.model_revision()
)
_download_from_hf_hub(HG_MODEL, "languages.json", revision=cls.model_revision())
revision = cls.model_revision()
for filename in _TOKENIZER_FILES:
_download_from_hf_hub(HG_MODEL, filename, revision=revision)
_download_from_hf_hub(HG_MODEL, ONNX_FILENAME, subfolder="onnx", revision=revision)
_download_from_hf_hub(HG_MODEL, "languages.json", revision=revision)


class EOUPlugin(Plugin):
Expand Down
3 changes: 2 additions & 1 deletion livekit-plugins/livekit-plugins-turn-detector/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ classifiers = [
]
dependencies = [
"livekit-agents>=1.5.8",
"transformers>=4.47.1,!=4.57.2,!=4.57.3", # 4.57.2-4.57.3 have a bug with local_files_only=True (huggingface/transformers#42369)
"tokenizers>=0.20,<1",
"huggingface-hub>=0.25",
"numpy>=1.26",
"onnxruntime>=1.18",
"jinja2",
Expand Down
Loading
Loading