Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions tensorrt_llm/_torch/models/modeling_radio.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,10 +1024,11 @@ def __init__(self,
self.model_config = copy.deepcopy(model_config)
if self.model_config.quant_config is not None:
if disable_quantization:
# The basic method `apply_quant_config_exclude_modules` in DecoderModelForCausalLM keeps the kv_cache_quant_algo so we also keep it here.
self.model_config.quant_config = QuantConfig(
kv_cache_quant_algo=self.model_config.quant_config.
kv_cache_quant_algo)
# Vision encoder runs with kv_cache_manager=None, so there is no KV cache to
# quantize. Keeping kv_cache_quant_algo would make FlashInfer raise:
# "FP8 KV cache is not supported without a KV cache manager" for FP8 LLM checkpoints
# that specify FP8 KV Cache.
self.model_config.quant_config = QuantConfig()

self.model_config = dataclasses.replace(
self.model_config, attn_backend=vision_attn_backend)
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_a10.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ l0_a10:
- unittest/_torch/modeling/test_modeling_cohere2.py
- unittest/_torch/modeling/test_nemotron_nano_preprocessing.py
- unittest/_torch/modeling/test_modeling_parakeet.py
- unittest/_torch/modeling/test_modeling_radio.py
- unittest/_torch/sampler/test_trtllm_sampler.py
- unittest/_torch/executor/test_async_transfer_manager.py
- unittest/_torch/executor/test_scheduler_serializable_output.py
Expand Down
95 changes: 95 additions & 0 deletions tests/unittest/_torch/modeling/test_modeling_radio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
from unittest import mock

import pytest
import torch
from transformers import PretrainedConfig

from tensorrt_llm._torch import model_config as model_config_lib
from tensorrt_llm._torch.models import modeling_radio
from tensorrt_llm._torch.models.modeling_radio import RADIOVisionModel
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo

# Minimal ViT-TIMM configuration: 2 transformer layers, 64-dim embeddings,
# 32px input — just enough model to exercise RADIOVisionModel's forward path
# cheaply on CI hardware.
_TINY_VIT = modeling_radio.VITTIMMConfig(
    embed_dim=64,
    depth=2,
    num_attention_heads=2,
    intermediate_size=128,
    img_size=32,
)


def _make_vision_config():
    """Minimal PretrainedConfig mimicking Nemotron-Nano-V3's `vision_config`.

    Patterned on the `vision_config` block of config.json shipped with newer nemotron nano
    multimodal models; fields pared down to what RADIOVisionModel reads.
    """
    # Attributes RADIOVisionModel reads directly off the vision config.
    attrs = {
        "patch_size": 16,
        "adaptor_names": None,
        "feature_normalizer_config": None,
        "inter_feature_normalizer_config": None,
        "max_resolution": 64,
        "vitdet_window_size": None,
        "preferred_resolution": (32, 32),
        "video_temporal_patch_size": 1,
        "separate_video_embedder": True,
        "torch_dtype": torch.bfloat16,
        # Nested `args` dict mirrors the RADIO checkpoint's model-builder arguments;
        # "vit_tiny_test" is resolved via the tiny_vit_config fixture's patched registry.
        "args": {
            "model": "vit_tiny_test",
            "in_chans": None,
            "input_size": None,
            "drop": 0.0,
            "cpe_max_size": 64,
            "cls_token_per_teacher": False,
            "teachers": [{"name": "dummy"}],
            "register_multiple": None,
            "cpe_num_registers": None,
        },
    }
    config = PretrainedConfig()
    for name, value in attrs.items():
        setattr(config, name, value)
    return config


@pytest.fixture
def tiny_vit_config():
    """Register the tiny ViT architecture under the name the test config refers to.

    Patches modeling_radio's name->config registry for the duration of the test
    and restores it afterwards, even if the test raises.
    """
    patcher = mock.patch.dict(
        modeling_radio.VIT_TIMM_CONFIG_BY_NAME,
        {"vit_tiny_test": _TINY_VIT},
    )
    patcher.start()
    try:
        yield
    finally:
        patcher.stop()


def _make_fp8_model_config():
    """Build a ModelConfig whose parent LLM uses FP8 weights and an FP8 KV cache.

    This is the quantization shape that must NOT leak into the vision tower.
    """
    fp8_quant = QuantConfig(
        quant_algo=QuantAlgo.FP8,
        kv_cache_quant_algo=QuantAlgo.FP8,
    )
    return model_config_lib.ModelConfig(
        pretrained_config=_make_vision_config(),
        quant_config=fp8_quant,
    )


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
def test_radio_fp8_parent_kv_cache_does_not_leak_into_vit(tiny_vit_config):
    """When the parent LLM uses FP8 KV cache, the RADIO vision encoder must not inherit it.

    A ViT has no KV cache (kv_cache_manager=None). If `kv_cache_quant_algo=FP8` leaks into the
    vision tower, FlashInfer raises at forward time about it not being supported.
    """
    device = torch.device("cuda")
    dtype = torch.bfloat16

    # disable_quantization=True is the code path under test: it must strip the
    # parent checkpoint's quant config (including KV-cache algo) from the ViT.
    model = RADIOVisionModel(_make_fp8_model_config(), disable_quantization=True)
    model = model.to(device).to(dtype)

    # 32x32 image: multiple of `patch_size=16`, so `min_resolution_step` is satisfied.
    image = torch.randn(1, 3, 32, 32, device=device, dtype=dtype)

    with torch.inference_mode():
        out = model.forward(image)

    # One image in, one feature row out; channel dim matches the tiny ViT width.
    assert out.shape[0] == 1
    assert out.shape[-1] == _TINY_VIT.embed_dim
Loading