Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions tensorrt_llm/serve/openai_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,18 @@
from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory


def _logit_bias_to_embedding_bias(logit_bias: Optional[Dict[str, float]],
vocab_size: int) -> Optional[torch.Tensor]:
def _logit_bias_to_embedding_bias(
logit_bias: Optional[Dict[str, float]],
vocab_size: Optional[int]) -> Optional[torch.Tensor]:
"""Convert OpenAI logit_bias dict to embedding_bias tensor for sampling."""
if logit_bias is None:
return None
if vocab_size is None:
raise ValueError(
"logit_bias requires a tokenizer, but the server was started "
"without one (e.g. num_postprocess_workers > 0). "
"Remove logit_bias from your request or set num_postprocess_workers=0."
)

# Create 1D zeros tensor as expected by executor API (will be unsqueezed to [1, vocab_size] internally)
embedding_bias = torch.zeros(vocab_size, dtype=torch.float32)
Expand Down Expand Up @@ -390,7 +397,7 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params

def to_sampling_params(self,
vocab_size: int = 32000,
vocab_size: Optional[int] = None,
gather_generation_logits: bool = False,
backend: Optional[str] = None) -> SamplingParams:
sampling_params = SamplingParams(
Expand Down Expand Up @@ -752,7 +759,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params

def to_sampling_params(self,
vocab_size: int = 32000,
vocab_size: Optional[int] = None,
gather_generation_logits: bool = False,
reasoning_parser: Optional[str] = None,
backend: Optional[str] = None) -> SamplingParams:
Expand Down
13 changes: 9 additions & 4 deletions tensorrt_llm/serve/openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,12 @@ def postproc_worker_enabled(self) -> bool:

return True if self.generator.args.num_postprocess_workers > 0 else False

@property
def _vocab_size(self) -> Optional[int]:
    """Vocabulary size of the wrapped tokenizer, or ``None`` when unavailable.

    Returns ``None`` when the server has no tokenizer wrapper at all, or when
    the wrapper holds no inner tokenizer (e.g. tokenizer-less deployments).
    """
    wrapper = self.tokenizer
    if wrapper is None:
        return None
    inner = wrapper.tokenizer
    if inner is None:
        return None
    return inner.vocab_size

@staticmethod
def create_error_response(
message: str,
Expand Down Expand Up @@ -1043,7 +1049,7 @@ async def chat_stream_generator(
# Pass the tokenizer vocabulary size so ``logit_bias`` can be
# expanded into an embedding bias tensor in the sampler.
sampling_params = request.to_sampling_params(
vocab_size=self.tokenizer.tokenizer.vocab_size,
vocab_size=self._vocab_size,
gather_generation_logits=self.generator.args.
gather_generation_logits,
reasoning_parser=self.generator.args.reasoning_parser,
Expand Down Expand Up @@ -1375,7 +1381,7 @@ async def generator_wrapper(generator: AsyncIterator[Any]):
# Pass the tokenizer vocabulary size so ``logit_bias`` can be
# expanded into an embedding bias tensor in the sampler.
sampling_params = request.to_sampling_params(
vocab_size=self.tokenizer.tokenizer.vocab_size,
vocab_size=self._vocab_size,
gather_generation_logits=self.generator.args.
gather_generation_logits,
backend=self.generator.args.backend)
Expand Down Expand Up @@ -1510,8 +1516,7 @@ async def create_streaming_generator(promise: RequestOutput,
request.stop_token_ids = harmony_stop_tokens

sampling_params = request.to_sampling_params(
vocab_size=self.tokenizer.tokenizer.vocab_size,
reasoning_parser="gpt_oss")
vocab_size=self._vocab_size, reasoning_parser="gpt_oss")
sampling_params.detokenize = False # Harmony adapter handles detokenization
disaggregated_params = to_llm_disaggregated_params(
request.disaggregated_params)
Expand Down
11 changes: 10 additions & 1 deletion tests/unittest/llmapi/apps/test_harmony_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@
get_harmony_adapter,
handle_streaming_response,
)
from tensorrt_llm.serve.openai_protocol import StreamOptions
from tensorrt_llm.serve.openai_protocol import StreamOptions, _logit_bias_to_embedding_bias
from tensorrt_llm.serve.openai_server import OpenAIServer

_harmony_available = True
except (ImportError, ModuleNotFoundError):
Expand Down Expand Up @@ -1192,5 +1193,13 @@ def test_stream_options_none_defaults_to_include(self):
harmony_adapter.cleanup_stream_state(request_id)


def test_none_tokenizer_num_postprocess_workers():
    """A tokenizer-less server exposes no vocab size, and logit_bias then fails loudly.

    Builds a bare ``OpenAIServer`` (bypassing ``__init__`` so no model is
    loaded), clears its tokenizer, and checks both the ``_vocab_size``
    property and the error raised when ``logit_bias`` is used without a
    vocabulary size.
    """
    bare_server = object.__new__(OpenAIServer)
    bare_server.tokenizer = None

    assert bare_server._vocab_size is None

    bias = {"0": 1.0}
    with pytest.raises(ValueError, match="logit_bias requires a tokenizer"):
        _logit_bias_to_embedding_bias(bias, vocab_size=None)


# Allow running this test module directly (verbose mode) without the pytest CLI.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
Loading