diff --git a/tensorrt_llm/_torch/models/modeling_radio.py b/tensorrt_llm/_torch/models/modeling_radio.py index 9286554e0da..454dcc0801b 100644 --- a/tensorrt_llm/_torch/models/modeling_radio.py +++ b/tensorrt_llm/_torch/models/modeling_radio.py @@ -1024,10 +1024,11 @@ def __init__(self, self.model_config = copy.deepcopy(model_config) if self.model_config.quant_config is not None: if disable_quantization: - # The basic method `apply_quant_config_exclude_modules` in DecoderModelForCausalLM keeps the kv_cache_quant_algo so we also keep it here. - self.model_config.quant_config = QuantConfig( - kv_cache_quant_algo=self.model_config.quant_config. - kv_cache_quant_algo) + # Vision encoder runs with kv_cache_manager=None, so there is no KV cache to + # quantize. Keeping kv_cache_quant_algo would make FlashInfer raise: + # "FP8 KV cache is not supported without a KV cache manager" for FP8 LLM checkpoints + # that specify FP8 KV Cache. + self.model_config.quant_config = QuantConfig() self.model_config = dataclasses.replace( self.model_config, attn_backend=vision_attn_backend) diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index aec96073607..f3a17e0158e 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -23,6 +23,7 @@ l0_a10: - unittest/_torch/modeling/test_modeling_cohere2.py - unittest/_torch/modeling/test_nemotron_nano_preprocessing.py - unittest/_torch/modeling/test_modeling_parakeet.py + - unittest/_torch/modeling/test_modeling_radio.py - unittest/_torch/sampler/test_trtllm_sampler.py - unittest/_torch/executor/test_async_transfer_manager.py - unittest/_torch/executor/test_scheduler_serializable_output.py diff --git a/tests/unittest/_torch/modeling/test_modeling_radio.py b/tests/unittest/_torch/modeling/test_modeling_radio.py new file mode 100644 index 00000000000..a4ba7104cc0 --- /dev/null +++ b/tests/unittest/_torch/modeling/test_modeling_radio.py @@ -0,0 +1,95 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +from unittest import mock + +import pytest +import torch +from transformers import PretrainedConfig + +from tensorrt_llm._torch import model_config as model_config_lib +from tensorrt_llm._torch.models import modeling_radio +from tensorrt_llm._torch.models.modeling_radio import RADIOVisionModel +from tensorrt_llm.models.modeling_utils import QuantConfig +from tensorrt_llm.quantization.mode import QuantAlgo + +_TINY_VIT = modeling_radio.VITTIMMConfig( + embed_dim=64, + depth=2, + num_attention_heads=2, + intermediate_size=128, + img_size=32, +) + + +def _make_vision_config(): + """Minimal PretrainedConfig mimicking Nemotron-Nano-V3's `vision_config`. + + Patterned on the `vision_config` block of config.json shipped with newer nemotron nano + multimodal models; fields pared down to what RADIOVisionModel reads. + """ + config = PretrainedConfig() + config.patch_size = 16 + config.adaptor_names = None + config.feature_normalizer_config = None + config.inter_feature_normalizer_config = None + config.max_resolution = 64 + config.vitdet_window_size = None + config.preferred_resolution = (32, 32) + config.video_temporal_patch_size = 1 + config.separate_video_embedder = True + config.torch_dtype = torch.bfloat16 + config.args = { + "model": "vit_tiny_test", + "in_chans": None, + "input_size": None, + "drop": 0.0, + "cpe_max_size": 64, + "cls_token_per_teacher": False, + "teachers": [{"name": "dummy"}], + "register_multiple": None, + "cpe_num_registers": None, + } + return config + + +@pytest.fixture +def tiny_vit_config(): + with mock.patch.dict( + modeling_radio.VIT_TIMM_CONFIG_BY_NAME, + {"vit_tiny_test": _TINY_VIT}, + ): + yield + + +def _make_fp8_model_config(): + vision_config = _make_vision_config() + quant_config = QuantConfig( + quant_algo=QuantAlgo.FP8, + kv_cache_quant_algo=QuantAlgo.FP8, + ) + return model_config_lib.ModelConfig( + pretrained_config=vision_config, + quant_config=quant_config, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_radio_fp8_parent_kv_cache_does_not_leak_into_vit(tiny_vit_config): + """When the parent LLM uses FP8 KV cache, the RADIO vision encoder must not inherit it. + + A ViT has no KV cache (kv_cache_manager=None). If `kv_cache_quant_algo=FP8` leaks into the + vision tower, FlashInfer raises at forward time about it not being supported. + """ + vision_model = RADIOVisionModel(_make_fp8_model_config(), disable_quantization=True) + + device = torch.device("cuda") + dtype = torch.bfloat16 + vision_model = vision_model.to(device).to(dtype) + + # 32x32 image: multiple of `patch_size=16`, so `min_resolution_step` is satisfied. + pixel_values = torch.randn(1, 3, 32, 32, device=device, dtype=dtype) + + with torch.inference_mode(): + features = vision_model.forward(pixel_values) + + assert features.shape[0] == 1 + assert features.shape[-1] == _TINY_VIT.embed_dim