10 changes: 7 additions & 3 deletions docs/source/models/visual-generation.md
@@ -41,11 +41,15 @@ Models are auto-detected from the checkpoint directory. Diffusers-format models
| **FLUX.1** | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes |
| **FLUX.2** | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes |
| **Wan 2.1** | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| **Wan 2.2** | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes |
| **LTX-2** | Yes | Yes | No | Yes | Yes | No | No | Yes | Yes |
| **Wan 2.2** | Yes | Yes | Yes [^2] | Yes | Yes | Yes | Yes | Yes | Yes |
| **LTX-2** | Yes | Yes | Yes [^3] | Yes | Yes | No | No | Yes | Yes |

[^1]: FLUX models use embedded guidance and do not have a separate negative prompt path, so CFG parallelism is not applicable.

[^2]: Wan 2.2 has two stage transformers; TeaCache requires explicit `teacache.coefficients` (high-noise) and `teacache.coefficients_2` (low-noise). There is no built-in coefficient table for Wan 2.2.

[^3]: LTX-2 has no built-in TeaCache coefficient table in TRT-LLM; set `teacache.coefficients` explicitly when enabling TeaCache.

## Quick Start

Here is a simple example to generate a video with Wan 2.1:
@@ -108,7 +112,7 @@ args = VisualGenArgs(

### TeaCache

TeaCache caches transformer outputs when timestep embeddings change slowly between denoising steps, skipping redundant computation. Enable with `teacache.enable_teacache: true` (YAML config). The `teacache_thresh` parameter controls the similarity threshold.
TeaCache caches transformer outputs when timestep embeddings change slowly between denoising steps, skipping redundant computation. Enable with `teacache.enable_teacache: true` (YAML config). The `teacache_thresh` parameter controls the similarity threshold. For Wan 2.2, set both `coefficients` and `coefficients_2` (YAML or CLI). For LTX-2, set `coefficients` when enabling TeaCache (no built-in table). Other models (e.g. FLUX.1, FLUX.2, Wan 2.1) can omit `coefficients` to use the built-in checkpoint table.
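The `coefficients` list is read as a polynomial over the raw timestep-embedding distance; a minimal pure-Python sketch of that rescaling, assuming numpy-style highest-order-first ordering and using made-up coefficient values:

```python
def rescale_distance(raw_distance, coefficients):
    """Horner evaluation of poly(raw_distance), highest-order coefficient first
    (same convention as numpy.polyval)."""
    result = 0.0
    for c in coefficients:
        result = result * raw_distance + c
    return result

# Hypothetical coefficients: [1.0, 0.0, 0.5] means 1.0*d**2 + 0.0*d + 0.5
print(rescale_distance(0.3, [1.0, 0.0, 0.5]))  # ≈ 0.59
```

TeaCache then compares the accumulated rescaled distance against `teacache_thresh` to decide whether to reuse the cached transformer output for a step.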

### Multi-GPU Parallelism

2 changes: 2 additions & 0 deletions examples/visual_gen/README.md
@@ -243,6 +243,8 @@ python visual_gen_ltx2.py \
| `--image_cond_strength` | — | ✓ | 1.0 | Image conditioning strength |
| `--enable_teacache` | ✓ | ✓ | — | False | Cache optimization |
| `--teacache_thresh` | ✓ | ✓ | — | 0.2 | TeaCache similarity threshold |
| `--teacache_coefficients` | ✓ | ✓ | — | *(omit)* | Optional polynomial coefficients; overrides the built-in table |
| `--use_ret_steps` | ✓ | ✓ | — | False | TeaCache retention-steps mode (WAN/FLUX tables) |
| `--attention_backend` | ✓ | ✓ | — | VANILLA | `VANILLA`, `TRTLLM`, or `FA4` |
| `--cfg_size` | — | ✓ | — | 1 | CFG parallelism |
| `--ulysses_size` | ✓ | ✓ | — | 1 | Sequence parallelism |
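Putting the new flags together, an illustrative Wan 2.2 invocation might look like the following (model/prompt flags are elided, and the coefficient values are placeholders, not tuned values):

```shell
# Wan 2.2 (visual_gen_wan_t2v.py): TeaCache needs both stage polynomials.
python visual_gen_wan_t2v.py \
  --enable_teacache \
  --teacache_thresh 0.2 \
  --teacache_coefficients 1.0 0.0 0.5 \
  --teacache_coefficients_2 0.9 0.1 0.4
```

For Wan 2.1 or FLUX checkpoints, both `--teacache_coefficients*` flags can simply be omitted to fall back to the built-in table.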
2 changes: 2 additions & 0 deletions examples/visual_gen/serve/README.md
@@ -54,6 +54,8 @@ Before running these examples, ensure you have:
```
For LTX-2, you need to provide a proper text_encoder_path in `./configs/ltx2.yml`.

**TeaCache:** Example YAML files set `enable_teacache` and `teacache_thresh` only. Omit `coefficients` to use each pipeline’s **built-in** coefficient table (checkpoint path matching). Add `coefficients: [ ... ]` under `teacache` only when you need to override those defaults.
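A hedged sketch of what such an override block could look like under the `teacache` key (key names follow the docs in this PR; the coefficient values are placeholders, not tuned values):

```yaml
teacache:
  enable_teacache: true
  teacache_thresh: 0.2
  # Optional for FLUX / Wan 2.1 (built-in table used when omitted);
  # required for LTX-2 and Wan 2.2:
  coefficients: [1.0, 0.0, 0.5]
  # Wan 2.2 only: polynomial for the low-noise transformer_2 stage
  coefficients_2: [0.9, 0.1, 0.4]
```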

## Examples

Current supported & tested models:
13 changes: 13 additions & 0 deletions examples/visual_gen/visual_gen_flux.py
@@ -192,6 +192,17 @@ def parse_args():
choices=["dynamic", "static"],
help="SCM steps_computation_policy (default: dynamic if not overridden).",
)
parser.add_argument(
"--teacache_coefficients",
nargs="+",
type=float,
default=None,
metavar="FLOAT",
help=(
"Optional TeaCache polynomial coefficients (overrides checkpoint table). "
"Example: --teacache_coefficients 1.0 0.0 0.5"
),
)

# Quantization
parser.add_argument(
@@ -283,6 +294,8 @@ def _teacache_config_from_args(args) -> TeaCacheConfig:
kwargs: dict = {"use_ret_steps": args.use_ret_steps}
if args.teacache_thresh is not None:
kwargs["teacache_thresh"] = args.teacache_thresh
if args.teacache_coefficients is not None:
kwargs["coefficients"] = list(args.teacache_coefficients)
return TeaCacheConfig(**kwargs)


27 changes: 27 additions & 0 deletions examples/visual_gen/visual_gen_wan_i2v.py
@@ -165,6 +165,29 @@ def parse_args():
choices=["dynamic", "static"],
help="SCM steps_computation_policy (default: dynamic if not overridden).",
)
parser.add_argument(
"--teacache_coefficients",
nargs="+",
type=float,
default=None,
metavar="FLOAT",
help=(
"Optional TeaCache polynomial coefficients (overrides checkpoint table). "
"Example: --teacache_coefficients 1.0 0.0 0.5"
),
)
parser.add_argument(
"--teacache_coefficients_2",
nargs="+",
type=float,
default=None,
metavar="FLOAT",
help=(
"Second polynomial for Wan 2.2 low-noise transformer_2 (requires "
"--teacache_coefficients for the high-noise transformer). "
"Ignored for Wan 2.1."
),
)

# Quantization
parser.add_argument(
@@ -246,6 +269,10 @@ def _teacache_config_from_args(args) -> TeaCacheConfig:
kwargs: dict = {"use_ret_steps": args.use_ret_steps}
if args.teacache_thresh is not None:
kwargs["teacache_thresh"] = args.teacache_thresh
if args.teacache_coefficients is not None:
kwargs["coefficients"] = list(args.teacache_coefficients)
if args.teacache_coefficients_2 is not None:
kwargs["coefficients_2"] = list(args.teacache_coefficients_2)
return TeaCacheConfig(**kwargs)


27 changes: 27 additions & 0 deletions examples/visual_gen/visual_gen_wan_t2v.py
@@ -159,6 +159,29 @@ def parse_args():
choices=["dynamic", "static"],
help="SCM steps_computation_policy (default: dynamic if not overridden).",
)
parser.add_argument(
"--teacache_coefficients",
nargs="+",
type=float,
default=None,
metavar="FLOAT",
help=(
"Optional TeaCache polynomial coefficients (overrides checkpoint table). "
"Example: --teacache_coefficients 1.0 0.0 0.5"
),
)
parser.add_argument(
"--teacache_coefficients_2",
nargs="+",
type=float,
default=None,
metavar="FLOAT",
help=(
"Second polynomial for Wan 2.2 low-noise transformer_2 (requires "
"--teacache_coefficients for the high-noise transformer). "
"Ignored for Wan 2.1."
),
)

# Quantization
parser.add_argument(
@@ -246,6 +269,10 @@ def _teacache_config_from_args(args) -> TeaCacheConfig:
kwargs: dict = {"use_ret_steps": args.use_ret_steps}
if args.teacache_thresh is not None:
kwargs["teacache_thresh"] = args.teacache_thresh
if args.teacache_coefficients is not None:
kwargs["coefficients"] = list(args.teacache_coefficients)
if args.teacache_coefficients_2 is not None:
kwargs["coefficients_2"] = list(args.teacache_coefficients_2)
return TeaCacheConfig(**kwargs)


23 changes: 19 additions & 4 deletions tensorrt_llm/_torch/visual_gen/config.py
@@ -143,13 +143,26 @@ class BaseCacheConfig(StrictBaseModel):


class TeaCacheConfig(BaseCacheConfig):
"""TeaCache step-caching acceleration config."""
"""TeaCache step-caching acceleration config.

Attributes:
teacache_thresh: Distance threshold for cache decisions (lower = more caching)
use_ret_steps: Use aggressive warmup mode (5 steps) vs minimal (1 step)
coefficients: Polynomial coefficients for rescaling embedding distances
(rescaled_distance = poly(raw_distance)). None uses the pipeline built-in
coefficient table (checkpoint path matching); a non-None list overrides that table.
coefficients_2: Second polynomial (Wan 2.2 dual-transformer low-noise stage only).
Required together with coefficients when enabling TeaCache on Wan 2.2.
ret_steps / cutoff_steps / num_steps: Filled at runtime by TeaCacheBackend.refresh().
_cnt: Internal step counter (reset per generation)
"""

cache_backend: Literal["teacache"] = "teacache"
teacache_thresh: float = PydanticField(0.2, gt=0.0)
use_ret_steps: bool = False

coefficients: List[float] = PydanticField(default_factory=lambda: [1.0, 0.0])
coefficients: Optional[List[float]] = None
coefficients_2: Optional[List[float]] = None

# Runtime state fields (initialized by TeaCacheBackend.refresh)
ret_steps: Optional[int] = None
@@ -164,9 +177,11 @@ class TeaCacheConfig(BaseCacheConfig):
@model_validator(mode="after")
def validate_teacache(self) -> "TeaCacheConfig":
"""Validate TeaCache configuration."""
# Validate coefficients
if len(self.coefficients) == 0:
# Validate coefficients (when provided)
if self.coefficients is not None and len(self.coefficients) == 0:
raise ValueError("TeaCache coefficients list cannot be empty")
if self.coefficients_2 is not None and len(self.coefficients_2) == 0:
raise ValueError("TeaCache coefficients_2 list cannot be empty")

# Validate ret_steps if set
if self.ret_steps is not None and self.ret_steps < 0:
44 changes: 22 additions & 22 deletions tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py
@@ -297,32 +297,32 @@ def load_weights(self, weights: dict) -> None:
self.transformer.eval()

def post_load_weights(self) -> None:
"""Post-load setup: TeaCache registration."""
"""Post-load setup: cache acceleration (TeaCache or Cache-DiT)."""
super().post_load_weights()
if self.transformer is not None:
# Register TeaCache extractor for FLUX.2 (must be after device placement)
# Only set guidance_param_name for variants with guidance_embeds
guidance_param = "guidance" if self.transformer.guidance_embeds else None
forward_params = [
"hidden_states",
"encoder_hidden_states",
"timestep",
"img_ids",
"txt_ids",
"guidance",
"return_dict",
]
register_extractor_from_config(
ExtractorConfig(
model_class_name="Flux2Transformer2DModel",
timestep_embed_fn=self._compute_flux2_timestep_embedding,
guidance_param_name=guidance_param,
forward_params=forward_params,
return_dict_default=False,
if self.model_config.cache_backend == "teacache":
# Register TeaCache extractor for FLUX.2 (must be after device placement)
# Only set guidance_param_name for variants with guidance_embeds
guidance_param = "guidance" if self.transformer.guidance_embeds else None
forward_params = [
"hidden_states",
"encoder_hidden_states",
"timestep",
"img_ids",
"txt_ids",
"guidance",
"return_dict",
]
register_extractor_from_config(
ExtractorConfig(
model_class_name="Flux2Transformer2DModel",
timestep_embed_fn=self._compute_flux2_timestep_embedding,
guidance_param_name=guidance_param,
forward_params=forward_params,
return_dict_default=False,
)
)
)

# TeaCache or Cache-DiT
self._setup_cache_acceleration(self.transformer, FLUX2_TEACACHE_COEFFICIENTS)

@property
21 changes: 13 additions & 8 deletions tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py
@@ -15,7 +15,7 @@
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizerFast

from tensorrt_llm._torch.utils import make_weak_ref
from tensorrt_llm._torch.visual_gen.cache.teacache import CacheContext
from tensorrt_llm._torch.visual_gen.cache.teacache import CacheContext, register_extractor
from tensorrt_llm._torch.visual_gen.config import PipelineComponent
from tensorrt_llm._torch.visual_gen.cuda_graph_runner import CUDAGraphRunner, CUDAGraphRunnerConfig
from tensorrt_llm._torch.visual_gen.output import MediaOutput
@@ -791,13 +791,18 @@ def post_load_weights(self) -> None:
"""Finalize after weight loading: TeaCache, Cache-DiT, derived attributes."""
super().post_load_weights()

# TODO: TeaCache disabled: LTX2_TEACACHE_COEFFICIENTS are unverified.
# To re-enable, uncomment the following lines and verify coefficients.
# register_extractor(
# "LTXModel",
# LTX2TeaCacheExtractor(self._compute_ltx2_timestep_embedding),
# )
# self._setup_teacache(self.transformer, coefficients=LTX2_TEACACHE_COEFFICIENTS)
# LTX-2: single transformer (one DiT for video+audio); TeaCache only with explicit coefficients.
if self.transformer is not None and self.model_config.cache_backend == "teacache":
if self.model_config.teacache.coefficients is None:
raise ValueError(
"TeaCache on LTX-2 requires explicit teacache.coefficients "
"(no built-in coefficient table)."
)
register_extractor(
"LTXModel",
LTX2TeaCacheExtractor(self._compute_ltx2_timestep_embedding),
)
self._setup_cache_acceleration(self.transformer, coefficients=None)

# Cache-DiT
if self.transformer is not None and self.model_config.cache_backend == "cache_dit":
39 changes: 30 additions & 9 deletions tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py
@@ -10,6 +10,7 @@

from tensorrt_llm._torch.visual_gen.cache.teacache import (
ExtractorConfig,
TeaCacheBackend,
register_extractor_from_config,
)
from tensorrt_llm._torch.visual_gen.config import PipelineComponent
@@ -84,13 +85,6 @@ def __init__(self, model_config):
self.boundary_ratio = getattr(model_config.pretrained_config, "boundary_ratio", None)
self.is_wan22 = self.boundary_ratio is not None

# Validate TeaCache compatibility before allocating GPU memory
if self.is_wan22 and model_config.cache_backend == "teacache":
raise ValueError(
"TeaCache is not supported for Wan 2.2 T2V models. "
"Use cache_backend='none' or 'cache_dit' (not 'teacache')."
)

super().__init__(model_config)

def _compute_wan_timestep_embedding(self, module, timestep=None, **kwargs):
@@ -290,13 +284,40 @@ def post_load_weights(self) -> None:
else:
if self.model_config.cache_backend == "cache_dit":
self._setup_cache_acceleration(self.transformer, coefficients=None)
# TeaCache is not supported for Wan 2.2 unless using Cache-DiT.
self.transformer_cache_backend = self.cache_accelerator
self.transformer_cache_backend = self.cache_accelerator

if self.transformer_2 is not None:
if hasattr(self.transformer_2, "post_load_weights"):
self.transformer_2.post_load_weights()

# Wan 2.2 TeaCache after both transformers' post_load_weights (FP8 scales, etc.)
if (
self.transformer is not None
and self.transformer_2 is not None
and self.is_wan22
and self.model_config.cache_backend == "teacache"
):
tc = self.model_config.teacache
if tc.coefficients is None or tc.coefficients_2 is None:
raise ValueError(
"Wan 2.2 TeaCache requires explicit teacache.coefficients and "
"teacache.coefficients_2 (high-noise and low-noise stage polynomials). "
"There is no built-in coefficient table for Wan 2.2."
)
cfg_high = tc.model_copy()
cfg_low = tc.model_copy(update={"coefficients": tc.coefficients_2})
logger.info("TeaCache: Initializing (Wan 2.2 high-noise transformer)...")
self.cache_backend = TeaCacheBackend(cfg_high)
self.cache_backend.enable(self.transformer)
self.transformer_cache_backend = self.cache_backend
logger.info("TeaCache: Initializing (Wan 2.2 low-noise transformer_2)...")
self.transformer_2_cache_backend = TeaCacheBackend(cfg_low)
self.transformer_2_cache_backend.enable(self.transformer_2)
self._teacache_backends = [
self.cache_backend,
self.transformer_2_cache_backend,
]

def _run_warmup(self, height: int, width: int, num_frames: int, steps: int) -> None:
with torch.no_grad():
self.forward(
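The Wan 2.2 dual-backend wiring above boils down to cloning one config with the second polynomial swapped into the `coefficients` slot. A pydantic-free sketch of that split (class and field names mirror `TeaCacheConfig`, but this is illustrative, not the real implementation):

```python
from dataclasses import dataclass, replace
from typing import List, Optional


@dataclass(frozen=True)
class TeaCacheCfg:
    """Stand-in for TeaCacheConfig: only the fields the split touches."""
    teacache_thresh: float = 0.2
    coefficients: Optional[List[float]] = None
    coefficients_2: Optional[List[float]] = None


def split_wan22_configs(tc):
    """High-noise stage keeps `coefficients`; the low-noise stage reuses the
    same config with `coefficients_2` substituted for `coefficients`."""
    if tc.coefficients is None or tc.coefficients_2 is None:
        raise ValueError(
            "Wan 2.2 TeaCache requires explicit coefficients and coefficients_2"
        )
    cfg_high = tc  # frozen dataclass: safe to share as-is
    cfg_low = replace(tc, coefficients=tc.coefficients_2)
    return cfg_high, cfg_low


high, low = split_wan22_configs(
    TeaCacheCfg(coefficients=[1.0, 0.0], coefficients_2=[0.5, 0.1])
)
print(high.coefficients, low.coefficients)  # -> [1.0, 0.0] [0.5, 0.1]
```

Each resulting config then backs its own `TeaCacheBackend`, matching the one-backend-per-transformer pairing in `post_load_weights`.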