8 changes: 8 additions & 0 deletions README.md
@@ -348,6 +348,14 @@ pytest
- 📚 Documentation
- 🐛 Bug fixes

**Quick template for new architecture support:**
```python
from quantllm import register_architecture, turbo

register_architecture("new-arch", base_model_type="llama")
model = turbo("org/new-arch-7b", base_model_fallback=True, trust_remote_code=True)
```

---

## 📜 License
36 changes: 36 additions & 0 deletions docs/guide/loading-models.md
@@ -74,6 +74,42 @@ model = turbo(
)
```

### New Architecture Fallbacks (for very recent model releases)

If `transformers` does not recognize a just-released architecture yet, register a fallback family:

```python
from quantllm import turbo, register_architecture

# Map new architecture/model_type to a compatible base family
register_architecture("newmodel", base_model_type="llama")

model = turbo(
"new-model-org/NewModel-7B",
model_type_override="llama", # optional explicit override
base_model_fallback=True, # retry with resolved fallback config
trust_remote_code=True,
)
```

You can also load from config only (no checkpoint weights) while waiting for upstream support:

```python
model = turbo(
"new-model-org/NewModel-7B",
from_config_only=True,
trust_remote_code=True,
)
```

#### Fast contribution template for new architectures

1. Add a registration in your code or PR:
- `register_architecture("new-arch", base_model_type="llama")`
2. Validate loading with:
- `turbo("org/model", base_model_fallback=True, trust_remote_code=True)`
3. Add/extend a focused test in `tests/test_architecture_fallback.py` (see the sketch below).
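
The test in step 3 only needs the resolution logic, not a real checkpoint. A minimal sketch, assuming `pytest` discovery and the public `TurboModel.resolve_model_type` helper added in this PR; the test body itself is illustrative, not part of the diff:

```python
# tests/test_architecture_fallback.py -- illustrative sketch only
from quantllm import TurboModel, register_architecture


def test_registered_alias_resolves_to_base_family():
    # Registration mutates a process-global registry; acceptable in a focused test.
    register_architecture("new-arch", base_model_type="llama")

    # With no config model_type or explicit override, resolution falls back to
    # token matching on the model name ("new-arch" appears between separators).
    assert TurboModel.resolve_model_type("org/new-arch-7b") == "llama"
```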

### Memory Options

```python
2 changes: 2 additions & 0 deletions quantllm/__init__.py
@@ -35,6 +35,7 @@
from .core import (
turbo,
TurboModel,
register_architecture,
SmartConfig,
HardwareProfiler,
ModelAnalyzer,
@@ -117,6 +118,7 @@ def show_banner(force: bool = False):
# Main API
"turbo",
"TurboModel",
"register_architecture",
"SmartConfig",
"HardwareProfiler",
"ModelAnalyzer",
3 changes: 2 additions & 1 deletion quantllm/core/__init__.py
@@ -8,7 +8,7 @@
from .hardware import HardwareProfiler
from .smart_config import SmartConfig
from .model_analyzer import ModelAnalyzer
-from .turbo_model import TurboModel, turbo
+from .turbo_model import TurboModel, turbo, register_architecture
from .compilation import (
compile_model,
compile_for_inference,
@@ -51,6 +51,7 @@
"ModelAnalyzer",
"TurboModel",
"turbo",
"register_architecture",
# Compilation
"compile_model",
"compile_for_inference",
210 changes: 206 additions & 4 deletions quantllm/core/turbo_model.py
@@ -5,9 +5,12 @@
"""

import os
import re
import shutil
import tempfile
-from typing import Optional, Dict, Any, Union, List
+import copy
+from functools import lru_cache
+from typing import Optional, Dict, Any, Union, List, Type
import torch
import torch.nn as nn
from transformers import (
@@ -32,6 +35,14 @@
"quantization": "Q4_K_M",
"push_quantization": None,
}
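# Model-name tokens mapped to known transformers model_type families; consulted
# by resolve_model_type() after the user registry when nothing else matches.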
DEFAULT_ARCHITECTURE_FALLBACKS = {
"llama": "llama",
"mistral": "mistral",
"mixtral": "mistral",
"qwen": "qwen2",
"phi": "phi",
"gemma": "gemma",
}


class TurboModel:
@@ -58,6 +69,9 @@ class TurboModel:
>>> model.export("gguf", "my_model.gguf")
"""

_architecture_registry: Dict[str, str] = {}
_model_class_registry: Dict[str, Type[PreTrainedModel]] = {}

def __init__(
self,
model: PreTrainedModel,
@@ -82,6 +96,162 @@ def __init__(
self._lora_applied = False
self.export_push_config = self._build_export_push_config(export_push_config)
self.verbose = verbose

@classmethod
def register_architecture(
cls,
architecture: str,
*,
base_model_type: Optional[str] = None,
model_class: Optional[Type[PreTrainedModel]] = None,
) -> None:
"""
Register a new architecture alias and optional explicit model class.

Args:
architecture: Architecture or model type name to register
base_model_type: Base model family to fall back to (e.g., "llama")
model_class: Explicit model class with from_pretrained()
"""
normalized = architecture.lower().strip()
if not normalized:
raise ValueError("architecture must be a non-empty string")

if base_model_type:
cls._architecture_registry[normalized] = base_model_type.lower().strip()

if model_class is not None:
cls._model_class_registry[normalized] = model_class

@classmethod
def resolve_model_type(
cls,
model_name: str,
*,
config_model_type: Optional[str] = None,
model_type_override: Optional[str] = None,
) -> Optional[str]:
"""
Resolve model type using override, registry, and default family patterns.

If config_model_type is provided but unregistered, the original config value
is returned unchanged.
"""
if model_type_override:
return model_type_override.lower().strip()

model_type = (config_model_type or "").lower().strip()
if model_type:
return cls._architecture_registry.get(model_type, model_type)

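# No usable config model_type: fall back to token matching on the model name.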
name = model_name.lower()
for pattern, fallback in cls._architecture_registry.items():
if cls._matches_model_name_pattern(name, pattern):
return fallback

for pattern, fallback in DEFAULT_ARCHITECTURE_FALLBACKS.items():
if cls._matches_model_name_pattern(name, pattern):
return fallback

return None

@classmethod
def _matches_model_name_pattern(cls, model_name: str, pattern: str) -> bool:
"""Return True when pattern appears as a token in model_name."""
return cls._compiled_model_name_pattern(pattern).search(model_name) is not None

@staticmethod
@lru_cache(maxsize=256)
def _compiled_model_name_pattern(pattern: str):
"""Compile and cache token-boundary regex patterns for model-name matching."""
escaped = re.escape(pattern)
# Match architecture tokens as standalone chunks split by separators.
return re.compile(rf"(^|[^a-z0-9]){escaped}([^a-z0-9]|$)")

@staticmethod
def _should_apply_quantization(
quantize: bool,
bits: int,
from_config_only: bool,
) -> bool:
"""Check whether quantization arguments should be added for loading."""
return quantize and bits < 16 and not from_config_only

@classmethod
def _load_model_with_fallback(
cls,
model_name: str,
model_kwargs: Dict[str, Any],
*,
trust_remote_code: bool,
hf_config: Optional[Any],
model_type_override: Optional[str],
base_model_fallback: bool,
from_config_only: bool,
) -> PreTrainedModel:
"""Load model with architecture fallback and optional config-only mode."""
resolved_model_type = cls.resolve_model_type(
model_name,
config_model_type=getattr(hf_config, "model_type", None),
model_type_override=model_type_override,
)
resolved_config = hf_config

if hf_config is not None and resolved_model_type:
current_model_type = getattr(hf_config, "model_type", None)
if current_model_type != resolved_model_type:
resolved_config = copy.deepcopy(hf_config)
setattr(resolved_config, "model_type", resolved_model_type)

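# Config-only mode builds the architecture without downloading any weights.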
if from_config_only:
if resolved_config is None:
raise ValueError(
"from_config_only=True requires a loadable config. "
"Try trust_remote_code=True or set model_type_override."
)
return AutoModelForCausalLM.from_config(
resolved_config,
trust_remote_code=trust_remote_code,
torch_dtype=model_kwargs.get("torch_dtype"),
)

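# Primary path: standard Auto loading with the caller's kwargs.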
try:
return AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
except Exception as primary_error:
if not base_model_fallback:
raise
fallback_error = None

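# Fallback 1: retry from_pretrained with the resolved fallback config attached.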
if resolved_config is not None:
fallback_kwargs = dict(model_kwargs)
fallback_kwargs["config"] = resolved_config
try:
return AutoModelForCausalLM.from_pretrained(model_name, **fallback_kwargs)
except Exception as fallback_config_error:
fallback_error = fallback_config_error

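# Fallback 2: try an explicitly registered model class for the resolved type.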
if resolved_model_type:
registered_cls = cls._model_class_registry.get(resolved_model_type)
if registered_cls is not None:
class_kwargs = dict(model_kwargs)
if resolved_config is not None:
class_kwargs["config"] = resolved_config
try:
return registered_cls.from_pretrained(model_name, **class_kwargs)
except Exception as fallback_registered_error:
fallback_error = fallback_registered_error

error_details = f" Last fallback error: {fallback_error}" if fallback_error else ""

raise RuntimeError(
"Failed to load model with AutoModelForCausalLM and fallback resolution.\n"
"Try one of:\n"
"1) Register with register_architecture(...) before loading.\n"
"2) Use model_type_override='<base_family>'.\n"
"3) Use from_config_only=True with a loadable config "
"(usually trust_remote_code=True)."
+ error_details
) from (fallback_error or primary_error)

@classmethod
def from_pretrained(
@@ -96,6 +266,9 @@ def from_pretrained(
# Advanced options
trust_remote_code: bool = True,
quantize: bool = True,
model_type_override: Optional[str] = None,
base_model_fallback: bool = True,
from_config_only: bool = False,
config_override: Optional[Dict[str, Any]] = None,
config: Optional[Dict[str, Any]] = None,
verbose: bool = True,
@@ -117,6 +290,9 @@
dtype: Override dtype (default: bf16 if available, else fp16)
trust_remote_code: Trust remote code in model
quantize: Whether to quantize the model
model_type_override: Override detected model_type for very new architectures
base_model_fallback: Retry loading with resolved base model config on failure
from_config_only: Build model from config only (without loading weights)
config_override: Dict to override any auto-detected settings
config: Shared export/push config (format, quantization, push_format, etc.)
verbose: Print loading progress
@@ -196,6 +372,8 @@ def from_pretrained(
"torch_dtype": smart_config.dtype,
}

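# Overwritten below when the Hugging Face config loads; reused for fallback resolution.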
hf_config = None

# Check if model is already quantized to prevent conflicts
try:
from transformers import AutoConfig
@@ -225,7 +403,7 @@
pass # Ignore config loading errors, proceed with defaults

# Apply quantization if requested
-if quantize and smart_config.bits < 16:
+if cls._should_apply_quantization(quantize, smart_config.bits, from_config_only):
model_kwargs.update(cls._get_quantization_kwargs(smart_config))

# Device map for memory management
@@ -240,9 +418,14 @@
if verbose:
task = p.add_task("Downloading & Loading model...", total=None)

-model = AutoModelForCausalLM.from_pretrained(
+model = cls._load_model_with_fallback(
model_name,
-**model_kwargs,
+model_kwargs,
trust_remote_code=trust_remote_code,
hf_config=hf_config,
model_type_override=model_type_override,
base_model_fallback=base_model_fallback,
from_config_only=from_config_only,
)

if verbose:
@@ -1892,6 +2075,25 @@ def _replace_with_triton(self, module: nn.Module, bits: int) -> int:
return count


def register_architecture(
architecture: str,
*,
base_model_type: Optional[str] = None,
model_class: Optional[Type[PreTrainedModel]] = None,
) -> None:
"""
Register a new architecture alias and optional explicit model class.

Example:
>>> register_architecture("my-new-model", base_model_type="llama")
"""
TurboModel.register_architecture(
architecture,
base_model_type=base_model_type,
model_class=model_class,
)


def turbo(
model: str,
*,