-
Notifications
You must be signed in to change notification settings - Fork 11
MTP: clean-up #9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
61e916c
5b92839
46c0801
f87f0f4
d769c57
84f00ce
f6f29e6
0712378
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5559,24 +5559,77 @@ class _Qwen35MtpMixin: | |||||||||||||||||||||||||||||||||||||||
| gguf_writer: gguf.GGUFWriter | ||||||||||||||||||||||||||||||||||||||||
| block_count: int | ||||||||||||||||||||||||||||||||||||||||
| tensor_map: gguf.TensorNameMap | ||||||||||||||||||||||||||||||||||||||||
| fname_out: Path | ||||||||||||||||||||||||||||||||||||||||
| ftype: Any | ||||||||||||||||||||||||||||||||||||||||
| metadata: Any | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # When true, `--mtp` was passed: filter out trunk weights so the resulting | ||||||||||||||||||||||||||||||||||||||||
| # GGUF carries only the MTP head and the shared embeddings/output tensors. | ||||||||||||||||||||||||||||||||||||||||
| mtp_only: bool = False | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # When true, `--no-mtp` was passed: drop `mtp.*` tensors and report block_count | ||||||||||||||||||||||||||||||||||||||||
| # as the trunk-only layer count, producing a GGUF with no MTP head. | ||||||||||||||||||||||||||||||||||||||||
| no_mtp: bool = False | ||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Needs to be added to |
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def __init__(self, *args, **kwargs): | ||||||||||||||||||||||||||||||||||||||||
| super().__init__(*args, **kwargs) | ||||||||||||||||||||||||||||||||||||||||
| self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("mtp_num_hidden_layers", 0) | ||||||||||||||||||||||||||||||||||||||||
| self.block_count = self.hparams["num_hidden_layers"] | ||||||||||||||||||||||||||||||||||||||||
| if not self.no_mtp: | ||||||||||||||||||||||||||||||||||||||||
| self.block_count += self.hparams.get("mtp_num_hidden_layers", 0) | ||||||||||||||||||||||||||||||||||||||||
| self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||||||||||||||||
| def filter_tensors(cls, item): | ||||||||||||||||||||||||||||||||||||||||
| name, _ = item | ||||||||||||||||||||||||||||||||||||||||
| if name.startswith("mtp."): | ||||||||||||||||||||||||||||||||||||||||
| # Qwen3Next drops `mtp.*` tensors; Qwen3.5/3.6 use them by default. `--no-mtp` opts out. | ||||||||||||||||||||||||||||||||||||||||
| if cls.no_mtp: | ||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||
| return item | ||||||||||||||||||||||||||||||||||||||||
| return super().filter_tensors(item) # ty: ignore[unresolved-attribute] | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def set_gguf_parameters(self): | ||||||||||||||||||||||||||||||||||||||||
| super().set_gguf_parameters() # ty: ignore[unresolved-attribute] | ||||||||||||||||||||||||||||||||||||||||
| if self.no_mtp: | ||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||
| if (n := self.hparams.get("mtp_num_hidden_layers", 0)) > 0: | ||||||||||||||||||||||||||||||||||||||||
| self.gguf_writer.add_nextn_predict_layers(n) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def prepare_metadata(self, vocab_only: bool): | ||||||||||||||||||||||||||||||||||||||||
| super().prepare_metadata(vocab_only=vocab_only) # ty: ignore[unresolved-attribute] | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| if not self.mtp_only: | ||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| output_type: str = self.ftype.name.partition("_")[2] | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| if self.fname_out.is_dir(): | ||||||||||||||||||||||||||||||||||||||||
| fname_default: str = gguf.naming_convention( | ||||||||||||||||||||||||||||||||||||||||
| self.metadata.name, self.metadata.basename, self.metadata.finetune, | ||||||||||||||||||||||||||||||||||||||||
| self.metadata.version, size_label=None, output_type=output_type, model_type=None) | ||||||||||||||||||||||||||||||||||||||||
| self.fname_out = self.fname_out / f"{Path(fname_default).stem}-MTP.gguf" | ||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||
| stem = self.fname_out.stem | ||||||||||||||||||||||||||||||||||||||||
| self.fname_out = self.fname_out.parent / f"{stem}-MTP{self.fname_out.suffix}" | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: | ||||||||||||||||||||||||||||||||||||||||
| # Multimodal Qwen3.5/3.6 wrap the text model under `model.language_model.*`. | ||||||||||||||||||||||||||||||||||||||||
| if name.startswith("model.language_model."): | ||||||||||||||||||||||||||||||||||||||||
| name = "model." + name[len("model.language_model."):] | ||||||||||||||||||||||||||||||||||||||||
| elif name.startswith("language_model."): | ||||||||||||||||||||||||||||||||||||||||
| name = name[len("language_model."):] | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| if self.mtp_only: | ||||||||||||||||||||||||||||||||||||||||
| # In --mtp mode keep only the MTP block plus the shared embedding/output tensors | ||||||||||||||||||||||||||||||||||||||||
| # that the standalone MTP graph references at inference time. | ||||||||||||||||||||||||||||||||||||||||
| keep = ( | ||||||||||||||||||||||||||||||||||||||||
| name.startswith("mtp.") or | ||||||||||||||||||||||||||||||||||||||||
| name in ("model.embed_tokens.weight", "model.norm.weight", "lm_head.weight") or | ||||||||||||||||||||||||||||||||||||||||
| name in ("embed_tokens.weight", "norm.weight") | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| if not keep: | ||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # Remap MTP block tensors to llama.cpp's layer-indexed nextn naming. | ||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
The
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you take a look again? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm drowning ATM, don't have time to look into the details, but preferably the If it works as-is right now, we can flag it for a later refactor instead.
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is part of |
||||||||||||||||||||||||||||||||||||||||
| # HF: mtp.layers.0.* (transformer block at MTP slot 0) | ||||||||||||||||||||||||||||||||||||||||
| # mtp.fc / mtp.pre_fc_norm_embedding / mtp.pre_fc_norm_hidden / mtp.norm | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -14034,6 +14087,14 @@ def parse_args() -> argparse.Namespace: | |||||||||||||||||||||||||||||||||||||||
| "--mmproj", action="store_true", | ||||||||||||||||||||||||||||||||||||||||
| help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.", | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| parser.add_argument( | ||||||||||||||||||||||||||||||||||||||||
| "--mtp", action="store_true", | ||||||||||||||||||||||||||||||||||||||||
| help="(Experimental) Export only the multi-token prediction (MTP) head as a separate GGUF, suitable for use as a speculative draft. Output file name will get a '-MTP' suffix.", | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| parser.add_argument( | ||||||||||||||||||||||||||||||||||||||||
| "--no-mtp", action="store_true", | ||||||||||||||||||||||||||||||||||||||||
| help="(Experimental) Exclude the multi-token prediction (MTP) head from the converted GGUF. Pair with --mtp on a second run to publish trunk and MTP as two files. Note: the split form duplicates embeddings, so the bundled default is more space-efficient overall.", | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| parser.add_argument( | ||||||||||||||||||||||||||||||||||||||||
| "--mistral-format", action="store_true", | ||||||||||||||||||||||||||||||||||||||||
| help="Whether the model is stored following the Mistral format.", | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -14193,6 +14254,18 @@ def main() -> None: | |||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||
| model_class = MistralModel | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| if args.mtp and args.no_mtp: | ||||||||||||||||||||||||||||||||||||||||
| logger.error("--mtp and --no-mtp are mutually exclusive") | ||||||||||||||||||||||||||||||||||||||||
| sys.exit(1) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| if (args.mtp or args.no_mtp) and not issubclass(model_class, _Qwen35MtpMixin): | ||||||||||||||||||||||||||||||||||||||||
| logger.error("--mtp / --no-mtp are only supported for Qwen3.5/3.6 text variants today") | ||||||||||||||||||||||||||||||||||||||||
| sys.exit(1) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # set on the class so __init__ sees the correct mode when computing block_count | ||||||||||||||||||||||||||||||||||||||||
| if args.no_mtp: | ||||||||||||||||||||||||||||||||||||||||
| model_class.no_mtp = True | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| model_instance = model_class(dir_model, output_type, fname_out, | ||||||||||||||||||||||||||||||||||||||||
| is_big_endian=args.bigendian, use_temp_file=args.use_temp_file, | ||||||||||||||||||||||||||||||||||||||||
| eager=args.no_lazy, | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -14205,6 +14278,9 @@ def main() -> None: | |||||||||||||||||||||||||||||||||||||||
| fuse_gate_up_exps=args.fuse_gate_up_exps | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| if args.mtp: | ||||||||||||||||||||||||||||||||||||||||
| model_instance.mtp_only = True | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| if args.vocab_only: | ||||||||||||||||||||||||||||||||||||||||
| logger.info("Exporting model vocab...") | ||||||||||||||||||||||||||||||||||||||||
| model_instance.write_vocab() | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.