27 changes: 24 additions & 3 deletions common/arg.cpp
@@ -335,11 +335,15 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa
struct handle_model_result {
bool found_mmproj = false;
common_params_model mmproj;

bool found_mtp = false;
common_params_model mtp;
};

static handle_model_result common_params_handle_model(struct common_params_model & model,
const std::string & bearer_token,
bool offline) {
bool offline,
bool search_mtp = false) {
handle_model_result result;

if (!model.docker_repo.empty()) {
@@ -354,7 +358,7 @@ static handle_model_result common_params_handle_model(struct common_params_model
common_download_opts opts;
opts.bearer_token = bearer_token;
opts.offline = offline;
auto download_result = common_download_model(model, opts, true);
auto download_result = common_download_model(model, opts, true, search_mtp);

if (download_result.model_path.empty()) {
LOG_ERR("error: failed to download model from Hugging Face\n");
@@ -368,6 +372,11 @@ static handle_model_result common_params_handle_model(struct common_params_model
result.found_mmproj = true;
result.mmproj.path = download_result.mmproj_path;
}

if (!download_result.mtp_path.empty()) {
result.found_mtp = true;
result.mtp.path = download_result.mtp_path;
}
} else if (!model.url.empty()) {
if (model.path.empty()) {
auto f = string_split<std::string>(model.url, '#').front();
@@ -588,7 +597,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context

// handle model and download
if (!skip_model_download) {
auto res = common_params_handle_model(params.model, params.hf_token, params.offline);
const bool spec_type_mtp = std::find(params.speculative.types.begin(),
params.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_MTP) != params.speculative.types.end();

auto res = common_params_handle_model(params.model, params.hf_token, params.offline, spec_type_mtp);
if (params.no_mmproj) {
params.mmproj = {};
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
@@ -602,6 +615,14 @@
break;
}
}
// when --spec-type mtp is set and no draft model was provided explicitly,
// fall back to the MTP head discovered alongside the -hf model
if (spec_type_mtp && res.found_mtp &&
params.speculative.draft.mparams.path.empty() &&
params.speculative.draft.mparams.hf_repo.empty() &&
params.speculative.draft.mparams.url.empty()) {
params.speculative.draft.mparams.path = res.mtp.path;
}
common_params_handle_model(params.speculative.draft.mparams, params.hf_token, params.offline);
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
}
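Net effect of the arg.cpp changes, as a minimal sketch (a restatement of the logic above, not new behavior; `res` is the handle_model_result returned for the -hf model):

    // draft-model precedence when --spec-type mtp is active:
    //   1. an explicitly provided draft (path, hf_repo, or url) always wins
    //   2. otherwise the MTP head discovered next to the -hf model is used
    auto & draft = params.speculative.draft.mparams;
    const bool draft_given = !draft.path.empty() ||
                             !draft.hf_repo.empty() ||
                             !draft.url.empty();
    if (!draft_given && res.found_mtp) {
        draft.path = res.mtp.path;  // e.g. a downloaded *-MTP.gguf sibling
    }
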
55 changes: 42 additions & 13 deletions common/download.cpp
@@ -566,8 +566,11 @@ static hf_cache::hf_files get_split_files(const hf_cache::hf_files & files,
return result;
}

static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
const std::string & model) {
// pick the best sibling GGUF whose filename contains `keyword` (e.g. "mmproj" / "MTP"),
// preferring deeper shared directory prefix with the model, then closest quantization
static hf_cache::hf_file find_best_sibling(const hf_cache::hf_files & files,
const std::string & model,
const std::string & keyword) {
hf_cache::hf_file best;
size_t best_depth = 0;
int best_diff = 0;
@@ -579,20 +582,20 @@ static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,

for (const auto & f : files) {
if (!string_ends_with(f.path, ".gguf") ||
f.path.find("mmproj") == std::string::npos) {
f.path.find(keyword) == std::string::npos) {
continue;
}

auto mmproj_parts = string_split<std::string>(f.path, '/');
auto mmproj_dir = mmproj_parts.end() - 1;
auto sib_parts = string_split<std::string>(f.path, '/');
auto sib_dir = sib_parts.end() - 1;

auto [_, dir] = std::mismatch(model_parts.begin(), model_dir,
mmproj_parts.begin(), mmproj_dir);
if (dir != mmproj_dir) {
sib_parts.begin(), sib_dir);
if (dir != sib_dir) {
continue;
}

size_t depth = dir - mmproj_parts.begin();
size_t depth = dir - sib_parts.begin();
auto bits = extract_quant_bits(f.path);
auto diff = std::abs(bits - model_bits);

@@ -606,6 +609,16 @@ static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
return best;
}

static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
const std::string & model) {
return find_best_sibling(files, model, "mmproj");
}

static hf_cache::hf_file find_best_mtp(const hf_cache::hf_files & files,
const std::string & model) {
return find_best_sibling(files, model, "MTP");
}

static bool gguf_filename_is_model(const std::string & filepath) {
if (!string_ends_with(filepath, ".gguf")) {
return false;
Expand All @@ -617,7 +630,8 @@ static bool gguf_filename_is_model(const std::string & filepath) {
}

return filename.find("mmproj") == std::string::npos &&
filename.find("imatrix") == std::string::npos;
filename.find("imatrix") == std::string::npos &&
filename.find("MTP") == std::string::npos;
}

static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
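A worked illustration of the sibling heuristic above, with hypothetical paths (not files from this PR). For a model at Q4_K_M/model-Q4_K_M.gguf, find_best_sibling(files, model, "mmproj") ranks candidates like this:

    // "mmproj-F16.gguf"            depth 0 (repo root)            -> candidate, shallower
    // "Q8_0/mmproj-Q8_0.gguf"      not on the model's dir path    -> skipped
    // "Q4_K_M/mmproj-Q8_0.gguf"    depth 1, quant diff |8-4| = 4  -> candidate
    // "Q4_K_M/mmproj-Q4_K_M.gguf"  depth 1, quant diff |4-4| = 0  -> selected
    //
    // a deeper shared directory prefix wins first; the quantization-bit distance
    // breaks ties. find_best_mtp() applies the same ranking to the "MTP" keyword.
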
@@ -673,11 +687,13 @@ struct hf_plan {
hf_cache::hf_file primary;
hf_cache::hf_files model_files;
hf_cache::hf_file mmproj;
hf_cache::hf_file mtp;
};

static hf_plan get_hf_plan(const common_params_model & model,
const common_download_opts & opts,
bool download_mmproj) {
bool download_mmproj,
bool download_mtp) {
hf_plan plan;
hf_cache::hf_files all;

@@ -723,6 +739,10 @@ static hf_plan get_hf_plan(const common_params_model & model,
plan.mmproj = find_best_mmproj(all, primary.path);
}

if (download_mtp) {
plan.mtp = find_best_mtp(all, primary.path);
}

return plan;
}

@@ -756,21 +776,25 @@ static std::vector<download_task> get_url_tasks(const common_params_model & mode

common_download_model_result common_download_model(const common_params_model & model,
const common_download_opts & opts,
bool download_mmproj) {
bool download_mmproj,
bool download_mtp) {
common_download_model_result result;
std::vector<download_task> tasks;
hf_plan hf;

bool is_hf = !model.hf_repo.empty();

if (is_hf) {
hf = get_hf_plan(model, opts, download_mmproj);
hf = get_hf_plan(model, opts, download_mmproj, download_mtp);
for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path});
}
if (!hf.mmproj.path.empty()) {
tasks.push_back({hf.mmproj.url, hf.mmproj.local_path});
}
if (!hf.mtp.path.empty()) {
tasks.push_back({hf.mtp.url, hf.mtp.local_path});
}
} else if (!model.url.empty()) {
tasks = get_url_tasks(model);
} else {
@@ -807,6 +831,10 @@ common_download_model_result common_download_model(const common_params_model &
if (!hf.mmproj.path.empty()) {
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
}

if (!hf.mtp.path.empty()) {
result.mtp_path = hf_cache::finalize_file(hf.mtp);
}
} else {
result.model_path = model.path;
}
@@ -946,7 +974,8 @@ std::vector<common_cached_model_info> common_list_cached_models() {
for (const auto & f : files) {
auto split = get_gguf_split_info(f.path);
if (split.index != 1 || split.tag.empty() ||
split.prefix.find("mmproj") != std::string::npos) {
split.prefix.find("mmproj") != std::string::npos ||
split.prefix.find("MTP") != std::string::npos) {
continue;
}
if (seen.insert(f.repo_id + ":" + split.tag).second) {
7 changes: 5 additions & 2 deletions common/download.h
@@ -59,6 +59,7 @@ struct common_download_opts {
struct common_download_model_result {
std::string model_path;
std::string mmproj_path;
std::string mtp_path;
};

// Download model from HuggingFace repo or URL
@@ -83,12 +84,14 @@ struct common_download_model_result {
// when opts.offline=true, no network requests are made
// when download_mmproj=true, searches for mmproj in same directory as model or any parent directory
// then with the closest quantization bits
// when download_mtp=true, applies the same sibling search for an MTP-head GGUF
//
// returns result with model_path and mmproj_path (empty on failure)
// returns result with model_path, mmproj_path and mtp_path (empty when not found / on failure)
common_download_model_result common_download_model(
const common_params_model & model,
const common_download_opts & opts = {},
bool download_mmproj = false
bool download_mmproj = false,
bool download_mtp = false
);

// returns list of cached models
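For reference, a caller-side sketch of the extended API (hypothetical repo name and surrounding variables; error handling elided):

    common_params_model model;
    model.hf_repo = "org/some-model-GGUF";  // hypothetical repo shipping a *-MTP.gguf sibling

    common_download_opts opts;
    opts.offline = false;

    auto res = common_download_model(model, opts,
                                     /*download_mmproj=*/false,
                                     /*download_mtp=*/true);
    if (!res.mtp_path.empty()) {
        // MTP head found and downloaded; usable as a speculative draft model
        params.speculative.draft.mparams.path = res.mtp_path;
    }
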
2 changes: 1 addition & 1 deletion common/speculative.cpp
@@ -1198,7 +1198,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
LOG_WRN("%s: draft model is not specified - cannot use 'draft' type\n", __func__);
has_draft = false;
}
} else if (has_draft_model) {
} else if (has_draft_model && !has_mtp && !has_draft_eagle3) {
LOG_WRN("%s: draft model is specified but 'draft' speculative type is not explicitly enabled - enabling it\n", __func__);
has_draft = true;
}
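Restating the one-line change for clarity (flag names as in the surrounding function): an explicitly specified draft model no longer force-enables the plain 'draft' type when MTP or EAGLE-3 already claims that model as its weights:

    const bool auto_enable_draft = has_draft_model && !has_mtp && !has_draft_eagle3;
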
78 changes: 77 additions & 1 deletion convert_hf_to_gguf.py
@@ -5559,24 +5559,77 @@ class _Qwen35MtpMixin:
gguf_writer: gguf.GGUFWriter
block_count: int
tensor_map: gguf.TensorNameMap
fname_out: Path
ftype: Any
metadata: Any

# When true, `--mtp` was passed: filter out trunk weights so the resulting
# GGUF carries only the MTP head and the shared embeddings/output tensors.
mtp_only: bool = False

# When true, `--no-mtp` was passed: drop `mtp.*` tensors and report block_count
# as the trunk-only layer count, producing a GGUF with no MTP head.
no_mtp: bool = False
Review comment: Needs to be added to ModelBase, you may need to use super() or add a getter to properly access these.


def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("mtp_num_hidden_layers", 0)
self.block_count = self.hparams["num_hidden_layers"]
if not self.no_mtp:
self.block_count += self.hparams.get("mtp_num_hidden_layers", 0)
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

@classmethod
def filter_tensors(cls, item):
name, _ = item
if name.startswith("mtp."):
# Qwen3Next drops `mtp.*` tensors; Qwen3.5/3.6 use them by default. `--no-mtp` opts out.
if cls.no_mtp:
return None
return item
return super().filter_tensors(item) # ty: ignore[unresolved-attribute]

def set_gguf_parameters(self):
super().set_gguf_parameters() # ty: ignore[unresolved-attribute]
if self.no_mtp:
return
if (n := self.hparams.get("mtp_num_hidden_layers", 0)) > 0:
self.gguf_writer.add_nextn_predict_layers(n)

def prepare_metadata(self, vocab_only: bool):
super().prepare_metadata(vocab_only=vocab_only) # ty: ignore[unresolved-attribute]

if not self.mtp_only:
return

output_type: str = self.ftype.name.partition("_")[2]

if self.fname_out.is_dir():
fname_default: str = gguf.naming_convention(
self.metadata.name, self.metadata.basename, self.metadata.finetune,
self.metadata.version, size_label=None, output_type=output_type, model_type=None)
self.fname_out = self.fname_out / f"{Path(fname_default).stem}-MTP.gguf"
else:
stem = self.fname_out.stem
self.fname_out = self.fname_out.parent / f"{stem}-MTP{self.fname_out.suffix}"

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Multimodal Qwen3.5/3.6 wrap the text model under `model.language_model.*`.
if name.startswith("model.language_model."):
name = "model." + name[len("model.language_model."):]
elif name.startswith("language_model."):
name = name[len("language_model."):]

if self.mtp_only:
# In --mtp mode keep only the MTP block plus the shared embedding/output tensors
# that the standalone MTP graph references at inference time.
keep = (
name.startswith("mtp.") or
name in ("model.embed_tokens.weight", "model.norm.weight", "lm_head.weight") or
name in ("embed_tokens.weight", "norm.weight")
)
if not keep:
return

# Remap MTP block tensors to llama.cpp's layer-indexed nextn naming.
Review comment (suggested change: delete the `model.language_model.*` remapping and the `mtp_only` keep-filter below, leaving only the nextn remap comment): The language_model stuff should be obsolete, and mtp_only can go in filter_tensors, no?

Author: Can you take a look again?

Reviewer: I'm drowning ATM, don't have time to look into the details, but preferably the mtp flags should go into ModelBase even though it may complicate access from a classmethod. If it works as-is right now, we can flag it for a later refactor instead.

Author: It is part of ModelBase now

# HF: mtp.layers.0.* (transformer block at MTP slot 0)
# mtp.fc / mtp.pre_fc_norm_embedding / mtp.pre_fc_norm_hidden / mtp.norm
@@ -14034,6 +14087,14 @@ def parse_args() -> argparse.Namespace:
"--mmproj", action="store_true",
help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.",
)
parser.add_argument(
"--mtp", action="store_true",
help="(Experimental) Export only the multi-token prediction (MTP) head as a separate GGUF, suitable for use as a speculative draft. Output file name will get a '-MTP' suffix.",
)
parser.add_argument(
"--no-mtp", action="store_true",
help="(Experimental) Exclude the multi-token prediction (MTP) head from the converted GGUF. Pair with --mtp on a second run to publish trunk and MTP as two files. Note: the split form duplicates embeddings, so the bundled default is more space-efficient overall.",
)
parser.add_argument(
"--mistral-format", action="store_true",
help="Whether the model is stored following the Mistral format.",
@@ -14193,6 +14254,18 @@ def main() -> None:
else:
model_class = MistralModel

if args.mtp and args.no_mtp:
logger.error("--mtp and --no-mtp are mutually exclusive")
sys.exit(1)

if (args.mtp or args.no_mtp) and not issubclass(model_class, _Qwen35MtpMixin):
logger.error("--mtp / --no-mtp are only supported for Qwen3.5/3.6 text variants today")
sys.exit(1)

# set on the class so __init__ sees the correct mode when computing block_count
if args.no_mtp:
model_class.no_mtp = True

Check failure (GitHub Actions / python type-check, ty unresolved-attribute): convert_hf_to_gguf.py:14267:13: Unresolved attribute `no_mtp` on type `type[ModelBase]`.

model_instance = model_class(dir_model, output_type, fname_out,
is_big_endian=args.bigendian, use_temp_file=args.use_temp_file,
eager=args.no_lazy,
Expand All @@ -14205,6 +14278,9 @@ def main() -> None:
fuse_gate_up_exps=args.fuse_gate_up_exps
)

if args.mtp:
model_instance.mtp_only = True

Check failure (GitHub Actions / python type-check, ty unresolved-attribute): convert_hf_to_gguf.py:14282:13: Unresolved attribute `mtp_only` on type `ModelBase`.

if args.vocab_only:
logger.info("Exporting model vocab...")
model_instance.write_vocab()
6 changes: 6 additions & 0 deletions include/llama.h
@@ -198,6 +198,11 @@ extern "C" {
LLAMA_SPLIT_MODE_TENSOR = 3,
};

enum llama_context_type {
LLAMA_CONTEXT_TYPE_DEFAULT = 0,
LLAMA_CONTEXT_TYPE_MTP = 1,
};

// TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979)
typedef struct llama_token_data {
llama_token id; // token id
@@ -339,6 +344,7 @@ extern "C" {
int32_t n_threads; // number of threads to use for generation
int32_t n_threads_batch; // number of threads to use for batch processing

enum llama_context_type ctx_type; // set the context type (e.g. MTP)
enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
enum llama_attention_type attention_type; // attention type to use for embeddings
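A minimal sketch of how a consumer could opt a context into the new type, assuming the MTP head is loaded as its own llama_model (the diff here only adds the enum and the params field; the actual wiring lives elsewhere in the PR):

    llama_context_params cparams = llama_context_default_params();
    cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP;  // default: LLAMA_CONTEXT_TYPE_DEFAULT

    // hypothetical: mtp_model is the *-MTP.gguf loaded via llama_model_load_from_file()
    llama_context * mtp_ctx = llama_init_from_model(mtp_model, cparams);
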
2 changes: 0 additions & 2 deletions src/llama-arch.cpp
@@ -41,8 +41,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
{ LLM_ARCH_QWEN35, "qwen35" },
{ LLM_ARCH_QWEN35MOE, "qwen35moe" },
{ LLM_ARCH_QWEN35_MTP, "qwen35_mtp" },
{ LLM_ARCH_QWEN35MOE_MTP, "qwen35moe_mtp" },
{ LLM_ARCH_PHI2, "phi2" },
{ LLM_ARCH_PHI3, "phi3" },
{ LLM_ARCH_PHIMOE, "phimoe" },