diff --git a/common/arg.cpp b/common/arg.cpp
index 27048e41149..4dbe5a3a87e 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -3550,12 +3550,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL"));
     add_opt(common_arg(
-        {"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
+        {"--spec-type"}, "[none|mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
         string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
             common_speculative_type_to_str(params.speculative.type).c_str()),
         [](common_params & params, const std::string & value) {
             if (value == "none") {
                 params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
+            } else if (value == "mtp") {
+                params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
             } else if (value == "ngram-cache") {
                 params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
             } else if (value == "ngram-simple") {
diff --git a/common/common.h b/common/common.h
index 6d13ac40e1f..33612ef5f5d 100644
--- a/common/common.h
+++ b/common/common.h
@@ -159,6 +159,7 @@ enum common_speculative_type {
     COMMON_SPECULATIVE_TYPE_NONE,          // no speculative decoding
     COMMON_SPECULATIVE_TYPE_DRAFT,         // draft model
     COMMON_SPECULATIVE_TYPE_EAGLE3,        // eagle draft model
+    COMMON_SPECULATIVE_TYPE_MTP,           // multi-token prediction head loaded from the target GGUF
     COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,  // simple self-speculative decoding
     COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,   // self-speculative decoding with n-gram keys only
     COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
diff --git a/common/speculative.cpp b/common/speculative.cpp
index 098066d15ce..3b36e04bdac 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -3,6 +3,7 @@
 #include "common.h"
 #include "ggml.h"
 #include "llama.h"
+#include "../src/llama-ext.h" // staging API: llama_set_embeddings_pre_norm / llama_get_embeddings_pre_norm_ith (used by MTP)
 #include "log.h"
 #include "ngram-cache.h"
 #include "ngram-map.h"
@@ -23,6 +24,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
    COMMON_SPECULATIVE_TYPE_NONE,
    COMMON_SPECULATIVE_TYPE_DRAFT,
    COMMON_SPECULATIVE_TYPE_EAGLE3,
+   COMMON_SPECULATIVE_TYPE_MTP,
    COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
    COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
    COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
@@ -34,6 +36,7 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
    {"none",          COMMON_SPECULATIVE_TYPE_NONE},
    {"draft",         COMMON_SPECULATIVE_TYPE_DRAFT},
    {"eagle3",        COMMON_SPECULATIVE_TYPE_EAGLE3},
+   {"mtp",           COMMON_SPECULATIVE_TYPE_MTP},
    {"ngram_simple",  COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
    {"ngram_map_k",   COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
    {"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
@@ -375,6 +378,290 @@ struct common_speculative_state_eagle3 : public common_speculative_impl {
     }
 };
 
+struct common_speculative_state_mtp : public common_speculative_impl {
+    common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft)
+
+    llama_batch batch;
+
+    std::vector<common_sampler_ptr> smpls;
+
+    int32_t n_embd = 0;
+
+    // Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
+    // The last h-row of one process() call needs the first token of the NEXT
+    // call to pair with, so it's stashed here until that next call fires.
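+    // A minimal walk-through (hypothetical batching): if process() first sees
+    // tokens [t0 t1 t2] and later [t3 t4] for one sequence, the first call
+    // stores h(t2) here, and the second call feeds (h(t2), t3) as the MTP
+    // input row for t3 before decoding on ctx_dft.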
+    std::vector<std::vector<float>> pending_h; // [n_seq][n_embd]
+
+    std::vector<int32_t> i_batch_beg;
+    std::vector<int32_t> i_batch_end;
+
+    common_speculative_state_mtp(const common_params_speculative & params, uint32_t n_seq)
+            : common_speculative_impl(COMMON_SPECULATIVE_TYPE_MTP, n_seq)
+            , params(params.draft)
+    {
+        auto * ctx_tgt = this->params.ctx_tgt;
+        auto * ctx_dft = this->params.ctx_dft;
+        GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set");
+
+        n_embd = llama_model_n_embd(llama_get_model(ctx_dft));
+
+        const int32_t n_b = (int32_t) llama_n_batch(ctx_dft);
+        batch = llama_batch_init(/*n_tokens=*/ n_b, /*embd=*/ n_embd, /*n_seq_max=*/ 1);
+        // llama_batch_init allocates only one of token/embd; MTP needs both.
+        // TODO: fix, how to call without malloc
+        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_b);
+
+        smpls.resize(n_seq);
+        for (auto & s : smpls) {
+            common_params_sampling sparams;
+            sparams.no_perf = false;
+            sparams.top_k = 10;
+            sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
+            s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
+        }
+
+        llama_set_embeddings_pre_norm(ctx_tgt, true);
+        llama_set_embeddings_pre_norm(ctx_dft, true);
+
+        pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
+
+        i_batch_beg.assign(n_seq, -1);
+        i_batch_end.assign(n_seq, -1);
+    }
+
+    ~common_speculative_state_mtp() override {
+        if (batch.token != nullptr) {
+            free(batch.token);
+            batch.token = nullptr;
+        }
+        llama_batch_free(batch);
+    }
+
+    void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
+        const int32_t N = (int32_t) prompt.size();
+        if (N <= 0) {
+            return;
+        }
+        auto * ctx_dft = this->params.ctx_dft;
+        const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
+        if (pos_max < N - 1) {
+            LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d — "
+                    "process() hook may not have run on every prefill ubatch "
+                    "(need_embd / logits=1 on every prompt position?). "
+                    "Drafts may degrade.\n",
+                    __func__, (int) pos_max, N - 1);
+        }
+    }
+
+    bool process(const llama_batch & batch_in) override {
+        if (batch_in.n_tokens <= 0) {
+            return true;
+        }
+
+        // TODO: how to make it work with vision tokens?
+        if (batch_in.token == nullptr || batch_in.embd != nullptr) {
+            return true;
+        }
+
+        const int32_t n_tokens = batch_in.n_tokens;
+
+        // remember the first and last batch index for each sequence
+        std::fill(i_batch_beg.begin(), i_batch_beg.end(), -1);
+        std::fill(i_batch_end.begin(), i_batch_end.end(), -1);
+
+        for (int k = 0; k < n_tokens; ++k) {
+            for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
+                GGML_ASSERT(batch_in.n_seq_id[k] == 1);
+
+                if (batch_in.seq_id[k][0] == seq_id) {
+                    i_batch_end[seq_id] = k;
+                    if (i_batch_beg[seq_id] < 0) {
+                        i_batch_beg[seq_id] = k;
+                    }
+                }
+            }
+        }
+
+        auto * ctx_tgt = this->params.ctx_tgt;
+        auto * ctx_dft = this->params.ctx_dft;
+
+        const size_t row_bytes = (size_t) n_embd * sizeof(float);
+
+        common_batch_clear(batch);
+
+        for (int k = 0; k < n_tokens; ++k) {
+            common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
+        }
+
+        // shift the tgt embeddings to the right by one position
+        // assumes that the tokens in the batch are sequential for each sequence
+        // i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
+        //                                                     ^--- this is a problem
+        // TODO: this is generally true, but would be nice to assert it
+        {
+            const float * h_tgt = llama_get_embeddings_pre_norm(ctx_tgt);
+            std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
+
+            //{
+            //    // string with seq_ids in the batch
+            //    std::stringstream ss;
+            //    for (int i = 0; i < n_tokens; ++i) {
+            //        ss << batch_in.seq_id[i][0] << ",";
+            //    }
+            //    LOG_WRN("%s: batch_in.seq_id = %s\n", __func__, ss.str().c_str());
+            //}
+        }
+
+        // fill the pending embeddings from a previous run
+        auto set_h = [&](int idx, const float * h_row) {
+            std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
+        };
+
+        for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
+            if (i_batch_beg[seq_id] < 0) {
+                continue;
+            }
+
+            set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
+        }
+
+        const int32_t rc = llama_decode(ctx_dft, batch);
+        if (rc != 0) {
+            LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
+            return false;
+        }
+
+        for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
+            if (i_batch_end[seq_id] < 0) {
+                continue;
+            }
+
+            const float * h_last = llama_get_embeddings_pre_norm_ith(ctx_tgt, i_batch_end[seq_id]);
+            std::memcpy(pending_h[seq_id].data(), h_last, row_bytes);
+        }
+
+        return true;
+    }
+
+    void draft(common_speculative_draft_params_vec & dparams) override {
+        auto & ctx_dft = params.ctx_dft;
+
+        common_batch_clear(batch);
+
+        // keep track of which sequences are still drafting
+        int n_drafting = 0;
+        std::vector<bool> drafting(n_seq);
+
+        const float * h_row = nullptr;
+        const size_t row_bytes = (size_t) n_embd * sizeof(float);
+
+        for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
+            auto & dp = dparams[seq_id];
+
+            if (!dp.drafting) {
+                continue;
+            }
+
+            n_drafting++;
+            drafting[seq_id] = true;
+            common_sampler_reset(smpls[seq_id].get());
+
+            common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true);
+
+            h_row = pending_h[seq_id].data();
+            std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
+        }
+
+        int ret = llama_decode(ctx_dft, batch);
+        if (ret != 0) {
+            LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
+            return;
+        }
+
+        int i = 0;
+
+        while (n_drafting > 0) {
+            int i_batch = 0;
+
+            common_batch_clear(batch);
+
+            for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
+                if (!drafting[seq_id]) {
+                    continue;
+                }
+
+                auto * smpl = smpls[seq_id].get();
+
+                common_sampler_sample(smpl, ctx_dft, i_batch, true);
+                h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, i_batch);
+                ++i_batch;
+
+                const auto * cur_p = common_sampler_get_candidates(smpl, true);
+
+                for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
+                    LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+                            seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
+                            common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
+                }
+
+                // add drafted token for each sequence
+                const llama_token id = cur_p->data[0].id;
+
+                // only collect very high-confidence draft tokens
+                if (cur_p->data[0].p < params.p_min) {
+                    drafting[seq_id] = false;
+                    n_drafting--;
+
+                    continue;
+                }
+
+                common_sampler_accept(smpl, id, true);
+
+                auto & dp = dparams.at(seq_id);
+                auto & result = *dp.result;
+
+                result.push_back(id);
+
+                if ((params.n_max <= (int) result.size()) ||
+                    (dp.n_max > 0 && dp.n_max <= (int) result.size())) {
+                    drafting[seq_id] = false;
+                    n_drafting--;
+                    continue;
+                }
+
+                common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
+                std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
+            }
+
+            if (batch.n_tokens == 0) {
+                break;
+            }
+
+            // evaluate the drafted tokens on the draft model
+            ret = llama_decode(ctx_dft, batch);
+            if (ret != 0) {
+                LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
+                break;
+            }
+
+            ++i;
+        }
+
+        for (auto & dp : dparams) {
+            if (!dp.drafting) {
+                continue;
+            }
+
+            if (dp.result->size() < (size_t) params.n_min) {
+                dp.result->clear();
+            }
+        }
+    }
+
+    void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
+    }
+};
+
 // state of self-speculation (simple implementation, not ngram-map)
 struct common_speculative_state_ngram_simple : public common_speculative_impl {
     common_params_speculative_ngram_map params;
@@ -818,6 +1105,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
         case COMMON_SPECULATIVE_TYPE_NONE:          return "none";
         case COMMON_SPECULATIVE_TYPE_DRAFT:         return "draft";
         case COMMON_SPECULATIVE_TYPE_EAGLE3:        return "eagle3";
+        case COMMON_SPECULATIVE_TYPE_MTP:           return "mtp";
         case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE:  return "ngram_simple";
         case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:   return "ngram_map_k";
         case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v";
@@ -843,6 +1131,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
 {
     bool has_draft        = !params.draft.mparams.path.empty();
     bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
+    bool has_mtp          = (params.type == COMMON_SPECULATIVE_TYPE_MTP) && params.draft.ctx_dft != nullptr;
 
     bool has_ngram_cache  = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
     bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
@@ -875,6 +1164,9 @@ common_speculative * common_speculative_init(common_params_speculative & params,
         if (has_draft_eagle3) {
             configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params));
         }
+        if (has_mtp) {
+            configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
+        }
     }
 
     std::vector<std::unique_ptr<common_speculative_impl>> impls = {};
@@ -892,6 +1184,10 @@
             impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.params, n_seq));
             break;
         }
+        case COMMON_SPECULATIVE_TYPE_MTP: {
+            impls.push_back(std::make_unique<common_speculative_state_mtp>(config.params, n_seq));
+            break;
+        }
         case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
             common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple);
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index fb1f5dd4473..177058b8157 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -5518,13 +5518,70 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_rope_dimension_sections(self._QWEN35_DEFAULT_MROPE_SECTION)
 
 
+class _Qwen35MtpMixin:
+    """Shared MTP wiring for Qwen3.5/3.6 text variants.
+
+    The HF config carries the MTP block under `mtp_num_hidden_layers` and the
+    tensors under `mtp.*`; we extend block_count, emit the nextn metadata key,
+    and remap `mtp.*` to the standard layer-indexed nextn naming so the
+    existing tensor_map handles them."""
+
+    # Class-level annotations so the type checker understands the attributes
+    # available on the concrete subclasses in the MRO
+    hparams: dict[str, Any]
+    model_arch: gguf.MODEL_ARCH
+    gguf_writer: gguf.GGUFWriter
+    block_count: int
+    tensor_map: gguf.TensorNameMap
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("mtp_num_hidden_layers", 0)
+        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()  # ty: ignore[unresolved-attribute]
+        if (n := self.hparams.get("mtp_num_hidden_layers", 0)) > 0:
+            self.gguf_writer.add_nextn_predict_layers(n)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Multimodal Qwen3.5/3.6 wrap the text model under `model.language_model.*`.
+        if name.startswith("model.language_model."):
+            name = "model." + name[len("model.language_model."):]
+        elif name.startswith("language_model."):
+            name = name[len("language_model."):]
+
+        # Remap MTP block tensors to llama.cpp's layer-indexed nextn naming.
+        #   HF: mtp.layers.0.* (transformer block at MTP slot 0)
+        #       mtp.fc / mtp.pre_fc_norm_embedding / mtp.pre_fc_norm_hidden / mtp.norm
+        if name.startswith("mtp."):
+            n_layer = self.hparams["num_hidden_layers"]
+            if name.find("layers.") != -1:
+                assert bid is not None
+                name = name.replace(f"mtp.layers.{bid}", f"model.layers.{bid + n_layer}")
+            else:
+                remapper = {
+                    "mtp.fc":                    "model.layers.{bid}.eh_proj",
+                    "mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm",
+                    "mtp.pre_fc_norm_hidden":    "model.layers.{bid}.hnorm",
+                    "mtp.norm":                  "model.layers.{bid}.shared_head.norm",
+                }
+                stem = Path(name).stem
+                suffix = Path(name).suffix
+                tmpl = remapper[stem] + suffix
+                for b in range(n_layer, self.block_count):
+                    yield from super().modify_tensors(data_torch, tmpl.format(bid=b), b)  # ty: ignore[unresolved-attribute]
+                return
+
+        yield from super().modify_tensors(data_torch, name, bid)  # ty: ignore[unresolved-attribute]
+
+
 @ModelBase.register("Qwen3_5ForConditionalGeneration", "Qwen3_5ForCausalLM")
-class Qwen3_5TextModel(_Qwen35MRopeMixin, _LinearAttentionVReorderBase):
+class Qwen3_5TextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase):
     model_arch = gguf.MODEL_ARCH.QWEN35
 
 
 @ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM")
-class Qwen3_5MoeTextModel(_Qwen35MRopeMixin, _LinearAttentionVReorderBase):
+class Qwen3_5MoeTextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase):
     model_arch = gguf.MODEL_ARCH.QWEN35MOE
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 308ebe1f4a1..4209470cb8c 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -2109,7 +2109,14 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.SSM_NORM,
         MODEL_TENSOR.SSM_BETA,
         MODEL_TENSOR.SSM_ALPHA,
-        MODEL_TENSOR.SSM_OUT
+        MODEL_TENSOR.SSM_OUT,
+        # NextN/MTP tensors - preserved but unused
+        MODEL_TENSOR.NEXTN_EH_PROJ,
+        MODEL_TENSOR.NEXTN_EMBED_TOKENS,
+        MODEL_TENSOR.NEXTN_ENORM,
+        MODEL_TENSOR.NEXTN_HNORM,
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
     ],
     MODEL_ARCH.QWEN35MOE: [
         MODEL_TENSOR.TOKEN_EMBD,
@@ -2140,7 +2147,14 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.SSM_NORM,
         MODEL_TENSOR.SSM_BETA,
         MODEL_TENSOR.SSM_ALPHA,
-        MODEL_TENSOR.SSM_OUT
+        MODEL_TENSOR.SSM_OUT,
+        # NextN/MTP tensors - preserved but unused
+        MODEL_TENSOR.NEXTN_EH_PROJ,
+        MODEL_TENSOR.NEXTN_EMBED_TOKENS,
+        MODEL_TENSOR.NEXTN_ENORM,
+        MODEL_TENSOR.NEXTN_HNORM,
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
     ],
     MODEL_ARCH.PLAMO: [
         MODEL_TENSOR.TOKEN_EMBD,
diff --git a/include/llama.h b/include/llama.h
index 308e8ba9dbd..1b896944735 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -310,6 +310,9 @@ extern "C" {
         // override key-value pairs of the model meta data
         const struct llama_model_kv_override * kv_overrides;
 
+        // override architecture from GGUF (e.g. load the MTP head of a Qwen3.5 GGUF as "qwen35_mtp")
+        const char * override_arch;
+
         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool vocab_only;    // only load the vocabulary, no weights
         bool use_mmap;      // use mmap if possible
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 59dde99e362..794666d09a4 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -41,6 +41,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_QWEN3VLMOE,    "qwen3vlmoe"    },
     { LLM_ARCH_QWEN35,        "qwen35"        },
     { LLM_ARCH_QWEN35MOE,     "qwen35moe"     },
+    { LLM_ARCH_QWEN35_MTP,    "qwen35_mtp"    },
+    { LLM_ARCH_QWEN35MOE_MTP, "qwen35moe_mtp" },
     { LLM_ARCH_PHI2,          "phi2"          },
     { LLM_ARCH_PHI3,          "phi3"          },
     { LLM_ARCH_PHIMOE,        "phimoe"        },
@@ -757,14 +759,15 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_INDEXER_PROJ,           {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_INDEXER_ATTN_K,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_INDEXER_ATTN_Q_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    // NextN/MTP tensors are currently ignored (reserved for future MTP support)
-    // These tensors only exist in the last layer(s) and are treated as output tensors
-    {LLM_TENSOR_NEXTN_EH_PROJ,          {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_NEXTN_EMBED_TOKENS,     {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
-    {LLM_TENSOR_NEXTN_ENORM,            {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
-    {LLM_TENSOR_NEXTN_HNORM,            {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
-    {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
+    // NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the
+    // last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so
+    // the model loader doesn't fault on the block index.
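+    // (With the REPEATING classification a name like "blk.47.nextn.eh_proj.weight"
+    // resolves through the per-layer tensor context; block index 47 is illustrative.)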
+    {LLM_TENSOR_NEXTN_EH_PROJ,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_NEXTN_EMBED_TOKENS,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
+    {LLM_TENSOR_NEXTN_ENORM,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
+    {LLM_TENSOR_NEXTN_HNORM,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     // Nemotron 3 Super
     {LLM_TENSOR_FFN_LATENT_DOWN,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_FFN_LATENT_UP,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
diff --git a/src/llama-arch.h b/src/llama-arch.h
index e37d548c98e..71c2ca6e6b3 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -45,6 +45,8 @@ enum llm_arch {
     LLM_ARCH_QWEN3VLMOE,
     LLM_ARCH_QWEN35,
     LLM_ARCH_QWEN35MOE,
+    LLM_ARCH_QWEN35_MTP,
+    LLM_ARCH_QWEN35MOE_MTP,
     LLM_ARCH_PHI2,
     LLM_ARCH_PHI3,
     LLM_ARCH_PHIMOE,
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 3d9714ab166..aea8a0a4e81 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -49,6 +49,7 @@ llama_context::llama_context(
     cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
     cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
     cparams.embeddings  = params.embeddings;
+    cparams.embeddings_pre_norm = false;
     cparams.offload_kqv = params.offload_kqv;
     cparams.no_perf     = params.no_perf;
     cparams.pooling_type = params.pooling_type;
@@ -860,6 +861,33 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
     return it->second.data();
 }
 
+float * llama_context::get_embeddings_pre_norm() {
+    output_reorder();
+
+    return embd_pre_norm.data;
+}
+
+float * llama_context::get_embeddings_pre_norm_ith(int32_t i) {
+    output_reorder();
+
+    try {
+        if (embd_pre_norm.data == nullptr) {
+            throw std::runtime_error("no pre-norm embeddings");
+        }
+
+        const int64_t j = output_resolve_row(i);
+        const uint32_t n_embd = model.hparams.n_embd;
+        return embd_pre_norm.data + j*n_embd;
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what());
+#ifndef NDEBUG
+        GGML_ABORT("fatal error");
+#else
+        return nullptr;
+#endif
+    }
+}
+
 llama_token llama_context::get_sampled_token_ith(int32_t idx) {
     output_reorder();
 
@@ -1040,6 +1068,12 @@ void llama_context::set_embeddings(bool value) {
     //sched_need_reserve = true;
 }
 
+void llama_context::set_embeddings_pre_norm(bool value) {
+    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
+
+    cparams.embeddings_pre_norm = value;
+}
+
 void llama_context::set_causal_attn(bool value) {
     LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
 
@@ -1241,7 +1275,9 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
 }
 
 int llama_context::encode(const llama_batch & batch_inp) {
-    GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
+    // MTP hook batches carry both token (next-token id) and embd (h_pre_norm row),
+    // so accept either present rather than requiring exactly one.
+    GGML_ASSERT(batch_inp.token || batch_inp.embd);
 
     if (batch_inp.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -1312,8 +1348,9 @@ int llama_context::encode(const llama_batch & batch_inp) {
         }
     }
 
-    auto * t_logits = res->get_logits();
-    auto * t_embd   = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
+    auto * t_logits     = res->get_logits();
+    auto * t_embd       = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
+    auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr;
 
     // extract logits
     if (logits.data && t_logits) {
@@ -1379,6 +1416,16 @@ int llama_context::encode(const llama_batch & batch_inp) {
         }
     }
 
+    // extract pre-norm embeddings (hidden state before the final output norm)
+    if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
+        GGML_ASSERT(backend_h != nullptr);
+
+        const uint32_t n_embd = hparams.n_embd;
+        GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size);
+        ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float));
+    }
+
     // TODO: hacky solution
     if (model.arch == LLM_ARCH_T5 && t_embd) {
         //cross.t_embd = t_embd;
@@ -1531,7 +1578,9 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map
-    auto * t_logits = res->get_logits();
-    auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
+    auto * t_logits     = res->get_logits();
+    auto * t_embd       = cparams.embeddings ? res->get_embd() : nullptr;
+    auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr;
 
     if (t_embd && res->get_embd_pooled()) {
         t_embd = res->get_embd_pooled();
@@ -1809,6 +1859,20 @@
             }
         }
 
+        // extract pre-norm embeddings (hidden state before the final output norm)
+        // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored.
+        if (embd_pre_norm.data && t_h_pre_norm && n_outputs > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
+            ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
+            GGML_ASSERT(backend_h != nullptr);
+
+            const uint32_t n_embd = hparams.n_embd;
+            float * embd_pre_norm_out = embd_pre_norm.data + n_outputs_prev*n_embd;
+
+            GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
+            GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_pre_norm.size);
+            ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_outputs*n_embd*sizeof(float));
+        }
+
         // Copy backend sampling output if this ubatch produced any sampling tensors.
         if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) {
             const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
@@ -1893,10 +1957,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto n_batch = cparams.n_batch;
     const auto n_vocab = vocab.n_tokens();
+    const auto n_embd = hparams.n_embd;
     const auto n_embd_out = hparams.n_embd_out();
 
-    bool has_logits = true;
-    bool has_embd   = cparams.embeddings;
+    bool has_logits        = true;
+    bool has_embd          = cparams.embeddings;
+    bool has_embd_pre_norm = cparams.embeddings_pre_norm;
 
     // TODO: hacky enc-dec support
     if (model.arch == LLM_ARCH_T5) {
@@ -1908,8 +1974,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     size_t backend_float_count = 0;
     size_t backend_token_count = 0;
 
-    logits.size = has_logits ? n_vocab*n_outputs_max : 0;
-    embd.size   = has_embd   ? n_embd_out*n_outputs_max : 0;
+    logits.size        = has_logits        ? n_vocab*n_outputs_max : 0;
+    embd.size          = has_embd          ? n_embd_out*n_outputs_max : 0;
+    embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0;
 
     // Allocate backend sampling output buffers if there are backend samplers configured.
     const bool has_sampling = !sampling.samplers.empty();
@@ -1925,8 +1992,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
 
     const size_t new_size =
-        (logits.size + embd.size + backend_float_count) * sizeof(float) +
-        (                          backend_token_count) * sizeof(llama_token);
+        (logits.size + embd.size + embd_pre_norm.size + backend_float_count) * sizeof(float) +
+        (                                               backend_token_count) * sizeof(llama_token);
 
     // alloc only when more than the current capacity is required
     // TODO: also consider shrinking the buffer
@@ -1942,6 +2009,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
             buf_output = nullptr;
             logits.data = nullptr;
             embd.data = nullptr;
+            embd_pre_norm.data = nullptr;
         }
 
         auto * buft = ggml_backend_cpu_buffer_type();
@@ -1970,6 +2038,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     embd = has_embd ? buffer_view{(float *) (base + offset), embd.size} : buffer_view{nullptr, 0};
     offset += embd.size * sizeof(float);
 
+    embd_pre_norm = has_embd_pre_norm ? buffer_view{(float *) (base + offset), embd_pre_norm.size} : buffer_view{nullptr, 0};
+    offset += embd_pre_norm.size * sizeof(float);
+
     if (has_sampling) {
         sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
         offset += sampling.logits.size * sizeof(float);
@@ -2034,6 +2105,12 @@ void llama_context::output_reorder() {
             }
         }
 
+        if (embd_pre_norm.size > 0) {
+            for (uint64_t k = 0; k < n_embd; k++) {
+                std::swap(embd_pre_norm.data[i0*n_embd + k], embd_pre_norm.data[i1*n_embd + k]);
+            }
+        }
+
         if (!sampling.samplers.empty()) {
             assert(sampling.logits.size > 0);
             assert(sampling.probs.size > 0);
@@ -3436,6 +3513,22 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
     return ctx->get_embeddings_seq(seq_id);
 }
 
+void llama_set_embeddings_pre_norm(llama_context * ctx, bool value) {
+    ctx->set_embeddings_pre_norm(value);
+}
+
+float * llama_get_embeddings_pre_norm(llama_context * ctx) {
+    ctx->synchronize();
+
+    return ctx->get_embeddings_pre_norm();
+}
+
+float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return ctx->get_embeddings_pre_norm_ith(i);
+}
+
 bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
     return ctx->set_sampler(seq_id, smpl);
 }
diff --git a/src/llama-context.h b/src/llama-context.h
index 92d1b0cf95a..e16ac4c618b 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -84,6 +84,9 @@ struct llama_context {
     float * get_embeddings_ith(int32_t i);
     float * get_embeddings_seq(llama_seq_id seq_id);
 
+    float * get_embeddings_pre_norm();
+    float * get_embeddings_pre_norm_ith(int32_t i);
+
     llama_token * get_sampled_tokens() const;
     llama_token get_sampled_token_ith(int32_t idx);
@@ -107,6 +110,7 @@ struct llama_context {
     void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);
 
     void set_embeddings (bool value);
+    void set_embeddings_pre_norm(bool value);
     void set_causal_attn(bool value);
     void set_warmup(bool value);
@@ -278,6 +282,11 @@ struct llama_context {
     // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
     buffer_view embd = {nullptr, 0};
 
+    // hidden state before the final output norm (2-dimensional array: [n_outputs][n_embd])
+    // populated only when cparams.embeddings_pre_norm is enabled and the model graph
+    // sets llm_graph_result::t_h_pre_norm
+    buffer_view embd_pre_norm = {nullptr, 0};
+
     struct sampling_info {
         // !samplers.empty() to check if any samplers are active
         std::map<llama_seq_id, llama_sampler *> samplers;
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index 9d359474132..1e4e9e29ed8 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -27,6 +27,7 @@ struct llama_cparams {
     float yarn_beta_slow;
 
     bool embeddings;
+    bool embeddings_pre_norm; // also extract the hidden state before the final output norm
     bool causal_attn;
     bool offload_kqv;
     bool flash_attn;
diff --git a/src/llama-ext.h b/src/llama-ext.h
index 8ce29d217cb..11f1986676a 100644
--- a/src/llama-ext.h
+++ b/src/llama-ext.h
@@ -88,3 +88,19 @@ LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model);
 LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i);
 
 LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx);
+
+//
+// pre-norm embeddings (hidden state before the final output norm)
+//
+
+// mirrors:
+//   LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
+LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value);
+
+// mirrors:
+//   LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
+LLAMA_API float * llama_get_embeddings_pre_norm(struct llama_context * ctx);
+
+// mirrors:
+//   LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
+LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i);
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 5cb1756c6a9..d3cd69a674c 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -644,6 +644,7 @@ class llm_graph_result {
     ggml_tensor * get_logits()      const { return t_logits; }
     ggml_tensor * get_embd()        const { return t_embd; }
     ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
+    ggml_tensor * get_h_pre_norm()  const { return t_h_pre_norm; }
 
     ggml_cgraph * get_gf() const { return gf; }
     ggml_context * get_ctx() const { return ctx_compute.get(); }
@@ -672,6 +673,7 @@ class llm_graph_result {
     ggml_tensor * t_logits      = nullptr;
     ggml_tensor * t_embd        = nullptr;
     ggml_tensor * t_embd_pooled = nullptr;
+    ggml_tensor * t_h_pre_norm  = nullptr; // [n_embd, n_outputs] hidden state before final output norm
 
     std::map<llama_seq_id, ggml_tensor *> t_sampled_logits;
     std::map<llama_seq_id, ggml_tensor *> t_candidates;
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 002d15d415f..2239309c8fb 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -229,6 +229,12 @@ uint32_t llama_hparams::n_embd_head_v_mla() const {
 }
 
 bool llama_hparams::has_kv(uint32_t il) const {
+    if (kv_only_nextn) {
+        // MTP head: only the trailing nextn_predict_layers blocks own a KV cache;
+        // the leading trunk blocks are not executed in this graph.
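+        // e.g. (hypothetical counts) n_layer = 48 with nextn_predict_layers = 1:
+        // has_kv() is true only for il == 47, the single MTP block.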
+        return nextn_predict_layers > 0 && il >= (n_layer - nextn_predict_layers);
+    }
+
     if (n_layer_kv_from_start >= 0) {
         if (il < (uint32_t) n_layer_kv_from_start) {
             return true;
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 0160a89caa2..e2d051edc6c 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -92,6 +92,8 @@ struct llama_hparams {
     uint32_t moe_latent_size = 0;
     uint32_t nextn_predict_layers = 0;
 
+    bool kv_only_nextn = false; // if true, only the last nextn_predict_layers blocks have a KV cache (MTP head arches)
+
     float f_norm_eps;
     float f_norm_rms_eps;
     float f_norm_group_eps;
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
index 4e65a45a50d..c645d0785ab 100644
--- a/src/llama-model-loader.cpp
+++ b/src/llama-model-loader.cpp
@@ -1312,9 +1312,16 @@ struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_conte
     return tensor;
 }
 
-void llama_model_loader::done_getting_tensors() const {
-    if (n_created != n_tensors) {
-        throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
+void llama_model_loader::done_getting_tensors(bool partial) const {
+    if (n_created > n_tensors) {
+        throw std::runtime_error(format("%s: too many tensors created; expected %d, got %d", __func__, n_tensors, n_created));
+    }
+    if (n_created < n_tensors) {
+        if (!partial) {
+            throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
+        }
+        LLAMA_LOG_INFO("%s: partial load — used %d of %d tensors in the file (rest belong to a sibling model on the same .gguf)\n",
+                __func__, n_created, n_tensors);
     }
 
     if (n_tensors_moved > 0) {
         LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %zu others) cannot be used with preferred buffer type %s, using %s instead\n",
diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h
index 7b3d6703c03..c476026d3e5 100644
--- a/src/llama-model-loader.h
+++ b/src/llama-model-loader.h
@@ -184,7 +184,7 @@ struct llama_model_loader {
     struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list<int64_t> & ne, size_t offset, bool required = true);
 
-    void done_getting_tensors() const;
+    void done_getting_tensors(bool partial = false) const;
 
     void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr);
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 9d011ff3464..ab1166af235 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -276,6 +276,10 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params
         case LLM_ARCH_QWEN35:
             return new llama_model_qwen35(params);
         case LLM_ARCH_QWEN35MOE:
            return new llama_model_qwen35moe(params);
+        case LLM_ARCH_QWEN35_MTP:
+            return new llama_model_qwen35_mtp(params);
+        case LLM_ARCH_QWEN35MOE_MTP:
+            return new llama_model_qwen35moe_mtp(params);
         case LLM_ARCH_MISTRAL3:
             return new llama_model_mistral3(params);
         case LLM_ARCH_MIMO2:
@@ -309,6 +313,15 @@ llama_model * llama_model_create(llama_model_loader & ml, const llama_model_para
     if (arch == LLM_ARCH_UNKNOWN) {
         throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'");
     }
+    if (params.override_arch != nullptr && params.override_arch[0] != '\0') {
+        const llm_arch override = llm_arch_from_string(params.override_arch);
+        if (override == LLM_ARCH_UNKNOWN) {
+            throw std::runtime_error(std::string("unknown override architecture: '") + params.override_arch + "'");
+        }
+        LLAMA_LOG_INFO("%s: overriding architecture %s -> %s\n",
+                __func__, llm_arch_name(arch), llm_arch_name(override));
+        arch = override;
+    }
 
     return llama_model_create(arch, params);
 }
@@ -1400,7 +1413,8 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) {
         }
     }
 
-    ml.done_getting_tensors();
+    const bool partial_load = (arch == LLM_ARCH_QWEN35_MTP || arch == LLM_ARCH_QWEN35MOE_MTP);
+    ml.done_getting_tensors(partial_load);
 
     // populate tensors_by_name
     for (auto & [_, ctx_ptr] : ml.ctx_map) {
@@ -2093,6 +2107,7 @@ llama_model_params llama_model_default_params() {
         /*.progress_callback           =*/ nullptr,
        /*.progress_callback_user_data =*/ nullptr,
        /*.kv_overrides                =*/ nullptr,
+       /*.override_arch               =*/ nullptr,
        /*.vocab_only                  =*/ false,
        /*.use_mmap                    =*/ true,
        /*.use_direct_io               =*/ false,
@@ -2317,6 +2332,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_QWEN3VLMOE:
         case LLM_ARCH_QWEN35:
         case LLM_ARCH_QWEN35MOE:
+        case LLM_ARCH_QWEN35_MTP:
+        case LLM_ARCH_QWEN35MOE_MTP:
             return LLAMA_ROPE_TYPE_IMROPE;
 
         case LLM_ARCH_GLM4:
diff --git a/src/models/models.h b/src/models/models.h
index 6d5f18a8e20..1f04d313d13 100644
--- a/src/models/models.h
+++ b/src/models/models.h
@@ -1785,6 +1785,32 @@ struct llama_model_qwen35moe : public llama_model_base {
 };
 
 
+struct llama_model_qwen35_mtp : public llama_model_base {
+    llama_model_qwen35_mtp(const struct llama_model_params & params) : llama_model_base(params) {}
+    void load_arch_hparams(llama_model_loader & ml) override;
+    void load_arch_tensors(llama_model_loader & ml) override;
+
+    struct graph : public llm_graph_context {
+        graph(const llama_model & model, const llm_graph_params & params);
+    };
+
+    std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
+};
+
+
+struct llama_model_qwen35moe_mtp : public llama_model_base {
+    llama_model_qwen35moe_mtp(const struct llama_model_params & params) : llama_model_base(params) {}
+    void load_arch_hparams(llama_model_loader & ml) override;
+    void load_arch_tensors(llama_model_loader & ml) override;
+
+    struct graph : public llm_graph_context {
+        graph(const llama_model & model, const llm_graph_params & params);
+    };
+
+    std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
+};
+
+
 struct llama_model_mistral3 : public llama_model_base {
     llama_model_mistral3(const struct llama_model_params & params) : llama_model_base(params) {}
     void load_arch_hparams(llama_model_loader & ml) override;
diff --git a/src/models/qwen35-mtp.cpp b/src/models/qwen35-mtp.cpp
new file mode 100644
index 00000000000..83039e98db5
--- /dev/null
+++ b/src/models/qwen35-mtp.cpp
@@ -0,0 +1,207 @@
+#include "models.h"
+
+void llama_model_qwen35_mtp::load_arch_hparams(llama_model_loader & ml) {
+    ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+    ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true);
+
+    ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
+    GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35_MTP requires nextn_predict_layers > 0");
+    GGML_ASSERT(hparams.nextn_predict_layers <= hparams.n_layer);
+
+    // only the MTP layers get a KV cache, trunk layers are skipped.
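+    // The trunk's linear-attention blocks are never executed by this graph, so no
+    // recurrent state is allocated for them either (recurrent_layer_arr cleared below).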
+    hparams.kv_only_nextn = true;
+    hparams.n_layer_kv_from_start = -1;
+    for (uint32_t i = 0; i < hparams.n_layer; ++i) {
+        hparams.recurrent_layer_arr[i] = false;
+    }
+
+    type = LLM_TYPE_UNKNOWN;
+}
+
+void llama_model_qwen35_mtp::load_arch_tensors(llama_model_loader &) {
+    LLAMA_LOAD_LOCALS;
+
+    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, TENSOR_NOT_REQUIRED);
+    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+    if (output == nullptr) {
+        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
+    }
+
+    const uint32_t n_main = n_layer - hparams.nextn_predict_layers;
+    for (int i = 0; i < n_layer; ++i) {
+        if (static_cast<uint32_t>(i) < n_main) {
+            continue; // trunk layer — owned by the sibling QWEN35 model
+        }
+
+        auto & layer = layers[i];
+
+        // MTP block looks like a full-attention Qwen3.5 decoder block.
+        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
+
+        create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0);
+        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
+        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
+        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
+
+        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+
+        // NextN-specific tensors that define the MTP block.
+        layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, 0);
+        layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, 0);
+        layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, 0);
+        layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+        layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+        layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
+    }
+}
+
+std::unique_ptr<llm_graph_context> llama_model_qwen35_mtp::build_arch_graph(const llm_graph_params & params) const {
+    return std::make_unique<graph>(*this, params);
+}
+
+// LLM_ARCH_QWEN35_MTP draft head for Qwen3.5/3.6 dense series
+llama_model_qwen35_mtp::graph::graph(const llama_model & model, const llm_graph_params & params)
+        : llm_graph_context(params) {
+    GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35_MTP requires nextn_predict_layers > 0");
+    GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35_MTP currently only supports a single MTP block");
+
+    const int64_t n_embd_head = hparams.n_embd_head_v();
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+
+    // The MTP block keeps its original layer index from the GGUF, appended after the trunk.
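+    // e.g. (hypothetical counts) a GGUF with 48 trunk blocks + 1 MTP block has
+    // n_layer = 49, so il == 48 and the weights come from blk.48.nextn.*.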
+    const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers;
+    const auto & layer = model.layers[il];
+
+    GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj");
+    GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm");
+    GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm");
+
+    int sections[4];
+    std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
+    auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd);
+
+    inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+    ggml_set_input(inp->tokens);
+
+    inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
+    ggml_set_input(inp->embd);
+    ggml_set_name(inp->embd, "mtp_h_input");
+
+    ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd;
+
+    ggml_tensor * h_input = inp->embd;
+    ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens);
+    cb(tok_embd, "mtp_tok_embd", il);
+
+    res->add_input(std::move(inp));
+
+    ggml_tensor * inp_pos = build_inp_pos();
+    auto * inp_attn = build_attn_inp_kv();
+
+    ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il);
+    cb(h_norm, "mtp_hnorm", il);
+
+    ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il);
+    cb(e_norm, "mtp_enorm", il);
+
+    ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0);
+    cb(concat, "mtp_concat", il);
+
+    ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat);
+    cb(cur, "mtp_eh_proj", il);
+
+    ggml_tensor * inpSA = cur;
+
+    cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il);
+    cb(cur, "mtp_attn_norm", il);
+
+    ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s);
+    cb(Qcur_full, "mtp_Qcur_full", il);
+
+    ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full,
+            n_embd_head, n_head, n_tokens,
+            ggml_element_size(Qcur_full) * n_embd_head * 2,
+            ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
+            0);
+    Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Qcur, "mtp_Qcur_normed", il);
+
+    ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full,
+            n_embd_head, n_head, n_tokens,
+            ggml_element_size(Qcur_full) * n_embd_head * 2,
+            ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
+            ggml_element_size(Qcur_full) * n_embd_head);
+    gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
+    cb(gate, "mtp_gate", il);
+
+    ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s);
+    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+    Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Kcur, "mtp_Kcur_normed", il);
+
+    ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s);
+    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+    cb(Vcur, "mtp_Vcur", il);
+
+    Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
+            n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+            ext_factor, attn_factor, beta_fast, beta_slow);
+    Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
+            n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+            ext_factor, attn_factor, beta_fast, beta_slow);
+
+    const float kq_scale = hparams.f_attention_scale == 0.0f
+            ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+    cur = build_attn(inp_attn,
+            nullptr, nullptr, nullptr,
+            Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+    cb(cur, "mtp_attn_pregate", il);
+
+    cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate));
+    cur = build_lora_mm(layer.wo, cur, layer.wo_s);
+    cb(cur, "mtp_attn_out", il);
+
+    cur = ggml_add(ctx0, cur, inpSA);
+    cb(cur, "mtp_attn_residual", il);
+
+    ggml_tensor * ffn_residual = cur;
+    cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il);
+    cb(cur, "mtp_attn_post_norm", il);
+
+    cur = build_ffn(cur,
+            layer.ffn_up, nullptr, layer.ffn_up_s,
+            layer.ffn_gate, nullptr, layer.ffn_gate_s,
+            layer.ffn_down, nullptr, layer.ffn_down_s,
+            nullptr,
+            LLM_FFN_SILU, LLM_FFN_PAR, il);
+    cb(cur, "mtp_ffn_out", il);
+
+    cur = ggml_add(ctx0, cur, ffn_residual);
+    cb(cur, "mtp_post_ffn", il);
+
+    // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step.
+    // (In the trunk graph this is `t_h_pre_norm`; the MTP head reuses the same slot.)
+    cb(cur, "h_pre_norm", -1);
+    res->t_h_pre_norm = cur;
+
+    ggml_tensor * head_norm_w = layer.nextn.shared_head_norm
+            ? layer.nextn.shared_head_norm
+            : model.output_norm;
+    GGML_ASSERT(head_norm_w && "QWEN35_MTP: missing both nextn.shared_head_norm and output_norm");
+    cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1);
+    cb(cur, "mtp_shared_head_norm", -1);
+
+    ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output;
+    GGML_ASSERT(head_w && "QWEN35_MTP: missing LM head (nextn.shared_head_head or model.output)");
+    cur = build_lora_mm(head_w, cur);
+    cb(cur, "result_output", -1);
+
+    res->t_logits = cur;
+    ggml_build_forward_expand(gf, cur);
+}
diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp
index f276be61ba8..79fdd8f679b 100644
--- a/src/models/qwen35.cpp
+++ b/src/models/qwen35.cpp
@@ -12,16 +12,23 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) {
     ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
     ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
 
-    // Mark recurrent layers (linear attention layers)
+    // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack
+    ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
+    GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
+    hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
+
+    // Mark recurrent layers (linear attention layers). MTP layers are dense
+    // attention-only and must be flagged non-recurrent.
     {
+        const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers;
         uint32_t full_attn_interval = 4;
         ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false);
         for (uint32_t i = 0; i < hparams.n_layer; ++i) {
-            hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0);
+            hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0);
         }
     }
 
-    switch (hparams.n_layer) {
+    switch (hparams.n_layer - hparams.nextn_predict_layers) {
         case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break;
         case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break;
         case 64: type = LLM_TYPE_27B; break;
@@ -83,6 +90,16 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader &) {
         layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
         layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
         layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+
+        // NextN/MTP tensors (preserved but unused) - only bound on MTP layers
+        if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+            layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED);
+            layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
+            layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
+            layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+            layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+            layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
+        }
     }
 }
 
@@ -111,7 +128,9 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
     ggml_tensor * inp_pos = build_inp_pos();
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
-    for (int il = 0; il < n_layer; ++il) {
+    // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass.
+    const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers;
+    for (int il = 0; il < n_transformer_layers; ++il) {
         ggml_tensor * inpSA = inpL;
 
         cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
@@ -128,7 +147,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
             cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
         }
 
-        if (il == n_layer - 1 && inp_out_ids) {
+        if (il == n_transformer_layers - 1 && inp_out_ids) {
             cur = ggml_get_rows(ctx0, cur, inp_out_ids);
             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
         }
@@ -160,6 +179,9 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
     }
 
     cur = inpL;
+    cb(cur, "h_pre_norm", -1);
+    res->t_h_pre_norm = cur;
+
     // Final norm
     cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
diff --git a/src/models/qwen35moe-mtp.cpp b/src/models/qwen35moe-mtp.cpp
new file mode 100644
index 00000000000..9f662213bee
--- /dev/null
+++ b/src/models/qwen35moe-mtp.cpp
@@ -0,0 +1,252 @@
+#include "models.h"
+
+void llama_model_qwen35moe_mtp::load_arch_hparams(llama_model_loader & ml) {
+    ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
+    ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
+    ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+    ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true);
+
+    ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
+    GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE_MTP requires nextn_predict_layers > 0");
+    GGML_ASSERT(hparams.nextn_predict_layers <= hparams.n_layer);
+    GGML_ASSERT(hparams.n_expert > 0 && "QWEN35MOE_MTP requires n_expert > 0");
+
+    // only the MTP layers get a KV cache, trunk layers are skipped.
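+    // Same scheme as the dense head in qwen35-mtp.cpp: has_kv() keys off kv_only_nextn.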
+    hparams.kv_only_nextn = true;
+    hparams.n_layer_kv_from_start = -1;
+    for (uint32_t i = 0; i < hparams.n_layer; ++i) {
+        hparams.recurrent_layer_arr[i] = false;
+    }
+
+    type = LLM_TYPE_UNKNOWN;
+}
+
+void llama_model_qwen35moe_mtp::load_arch_tensors(llama_model_loader &) {
+    LLAMA_LOAD_LOCALS;
+
+    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, TENSOR_NOT_REQUIRED);
+    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+    if (output == nullptr) {
+        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
+    }
+
+    const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
+    const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
+
+    const uint32_t n_main = n_layer - hparams.nextn_predict_layers;
+    for (int i = 0; i < n_layer; ++i) {
+        if (static_cast<uint32_t>(i) < n_main) {
+            continue; // trunk layer — owned by the sibling QWEN35MOE model
+        }
+
+        auto & layer = layers[i];
+
+        // MTP block looks like a full-attention Qwen3.5 decoder block with MoE FFN.
+        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
+
+        create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0);
+        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
+        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
+        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
+
+        // Routed experts
+        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
+        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
+        create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0);
+
+        // Shared experts
+        layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0);
+        layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0);
+        layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0);
+        layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0);
+
+        // NextN-specific tensors that define the MTP block.
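+        // eh_proj fuses the token embedding with the trunk hidden state:
+        // concat(e_norm, h_norm) in R^(2*n_embd) is projected back to n_embd,
+        // hence the { 2 * n_embd, n_embd } shape.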
+ layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr llama_model_qwen35moe_mtp::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +// LLM_ARCH_QWEN35MOE_MTP draft head for Qwen3.5/3.6 MoE +llama_model_qwen35moe_mtp::graph::graph(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE_MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35MOE_MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + GGML_ASSERT(layer.ffn_gate_inp && "MTP block missing ffn_gate_inp"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? 
+
+    ggml_tensor * inp_pos = build_inp_pos();
+    auto * inp_attn = build_attn_inp_kv();
+
+    ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il);
+    cb(h_norm, "mtp_hnorm", il);
+
+    ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il);
+    cb(e_norm, "mtp_enorm", il);
+
+    ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0);
+    cb(concat, "mtp_concat", il);
+
+    ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat);
+    cb(cur, "mtp_eh_proj", il);
+
+    ggml_tensor * inpSA = cur;
+
+    cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il);
+    cb(cur, "mtp_attn_norm", il);
+
+    ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s);
+    cb(Qcur_full, "mtp_Qcur_full", il);
+
+    ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full,
+            n_embd_head, n_head, n_tokens,
+            ggml_element_size(Qcur_full) * n_embd_head * 2,
+            ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
+            0);
+    Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Qcur, "mtp_Qcur_normed", il);
+
+    ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full,
+            n_embd_head, n_head, n_tokens,
+            ggml_element_size(Qcur_full) * n_embd_head * 2,
+            ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
+            ggml_element_size(Qcur_full) * n_embd_head);
+    gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
+    cb(gate, "mtp_gate", il);
+
+    ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s);
+    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+    Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Kcur, "mtp_Kcur_normed", il);
+
+    ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s);
+    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+    cb(Vcur, "mtp_Vcur", il);
+
+    Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
+            n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+            ext_factor, attn_factor, beta_fast, beta_slow);
+    Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
+            n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+            ext_factor, attn_factor, beta_fast, beta_slow);
+
+    const float kq_scale = hparams.f_attention_scale == 0.0f
+        ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+    cur = build_attn(inp_attn,
+            nullptr, nullptr, nullptr,
+            Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+    cb(cur, "mtp_attn_pregate", il);
+
+    cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate));
+    cur = build_lora_mm(layer.wo, cur, layer.wo_s);
+    cb(cur, "mtp_attn_out", il);
+
+    cur = ggml_add(ctx0, cur, inpSA);
+    cb(cur, "mtp_attn_residual", il);
+
+    ggml_tensor * ffn_residual = cur;
+    cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il);
+    cb(cur, "mtp_attn_post_norm", il);
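The strided views above encode a per-head interleaved layout in the fused Q projection: for every head, n_embd_head query values are followed by n_embd_head gate values. Worked numbers make the offsets concrete. With illustrative sizes n_embd_head = 128 and n_head = 16, Qcur_full holds 2 * 128 * 16 = 4096 values per token; head h's query starts at element 256 * h and its gate at 256 * h + 128, which is exactly what the two ggml_view_3d calls express:

    // Element offsets implied by the two views (sizes are illustrative only):
    const int64_t n_embd_head  = 128;
    const int64_t n_head       = 16;
    const int64_t head_stride  = 2 * n_embd_head;          // nb1 divided by element size
    const int64_t token_stride = 2 * n_embd_head * n_head; // nb2 divided by element size
    // head h, token t:
    //   query[h, t] starts at h * head_stride + t * token_stride
    //   gate [h, t] starts at h * head_stride + t * token_stride + n_embd_head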
+
+    // MoE FFN: routed experts plus gated shared expert (mirrors qwen35moe).
+    ggml_tensor * moe_out =
+        build_moe_ffn(cur,
+                layer.ffn_gate_inp,
+                layer.ffn_up_exps,
+                layer.ffn_gate_exps,
+                layer.ffn_down_exps,
+                nullptr,
+                n_expert, n_expert_used,
+                LLM_FFN_SILU, true,
+                hparams.expert_weights_scale,
+                LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il,
+                nullptr, layer.ffn_gate_up_exps,
+                layer.ffn_up_exps_s,
+                layer.ffn_gate_exps_s,
+                layer.ffn_down_exps_s);
+    cb(moe_out, "mtp_ffn_moe_out", il);
+
+    if (layer.ffn_up_shexp != nullptr) {
+        ggml_tensor * ffn_shexp =
+            build_ffn(cur,
+                    layer.ffn_up_shexp, nullptr, layer.ffn_up_shexp_s,
+                    layer.ffn_gate_shexp, nullptr, layer.ffn_gate_shexp_s,
+                    layer.ffn_down_shexp, nullptr, layer.ffn_down_shexp_s,
+                    nullptr,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+        cb(ffn_shexp, "mtp_ffn_shexp", il);
+
+        ggml_tensor * shared_gate = build_lora_mm(layer.ffn_gate_inp_shexp, cur);
+        shared_gate = ggml_sigmoid(ctx0, shared_gate);
+        cb(shared_gate, "mtp_shared_expert_gate_sigmoid", il);
+
+        ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
+        cb(ffn_shexp, "mtp_ffn_shexp_gated", il);
+
+        cur = ggml_add(ctx0, moe_out, ffn_shexp);
+    } else {
+        cur = moe_out;
+    }
+    cb(cur, "mtp_ffn_out", il);
+
+    cur = ggml_add(ctx0, cur, ffn_residual);
+    cb(cur, "mtp_post_ffn", il);
+
+    // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step.
+    cb(cur, "h_pre_norm", -1);
+    res->t_h_pre_norm = cur;
+
+    ggml_tensor * head_norm_w = layer.nextn.shared_head_norm
+        ? layer.nextn.shared_head_norm
+        : model.output_norm;
+    GGML_ASSERT(head_norm_w && "QWEN35MOE_MTP: missing both nextn.shared_head_norm and output_norm");
+    cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1);
+    cb(cur, "mtp_shared_head_norm", -1);
+
+    ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output;
+    GGML_ASSERT(head_w && "QWEN35MOE_MTP: missing LM head (nextn.shared_head_head or model.output)");
+    cur = build_lora_mm(head_w, cur);
+    cb(cur, "result_output", -1);
+
+    res->t_logits = cur;
+    ggml_build_forward_expand(gf, cur);
+}
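Seen from the caller's side, the graph exports two tensors per evaluated row: t_logits for sampling a draft token and t_h_pre_norm for seeding the next step, which is what lets a single one-block head draft several tokens autoregressively. A rough sketch of that loop follows; decode_mtp_step and sample_greedy are placeholders, not this PR's API, and only llama_get_logits_ith plus the staging llama_get_embeddings_pre_norm_ith are real accessors here.

    // Hypothetical AR draft loop (placeholder names, for orientation only):
    std::vector<llama_token> draft;
    std::vector<float> h = h0; // pre-norm trunk state at the last accepted position
    llama_token last_tok = t0; // last accepted token
    for (int i = 0; i < n_draft_max; ++i) {
        decode_mtp_step(ctx_dft, h.data(), last_tok, pos + i);          // placeholder
        const float * logits = llama_get_logits_ith(ctx_dft, 0);
        const llama_token tok = sample_greedy(logits, n_vocab);         // placeholder
        draft.push_back(tok);
        // re-seed h from the head's own pre-norm output (staging API)
        const float * h_next = llama_get_embeddings_pre_norm_ith(ctx_dft, 0);
        h.assign(h_next, h_next + n_embd);
        last_tok = tok;
    }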
diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp
index cf05dc9d61c..5912aa38153 100644
--- a/src/models/qwen35moe.cpp
+++ b/src/models/qwen35moe.cpp
@@ -15,16 +15,23 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) {
     ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
     ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
 
-    // Mark recurrent layers (linear attention layers)
+    // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack
+    ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
+    GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
+    hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
+
+    // Mark recurrent layers (linear attention layers). MTP layers are dense
+    // attention-only and must be flagged non-recurrent.
     {
+        const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers;
         uint32_t full_attn_interval = 4;
         ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false);
         for (uint32_t i = 0; i < hparams.n_layer; ++i) {
-            hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0);
+            hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0);
         }
     }
 
-    switch (hparams.n_layer) {
+    switch (hparams.n_layer - hparams.nextn_predict_layers) {
         case 40: type = LLM_TYPE_35B_A3B; break;
         case 48: type = LLM_TYPE_122B_A10B; break;
         case 60: type = LLM_TYPE_397B_A17B; break;
@@ -96,6 +103,16 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) {
             layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0);
             layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0);
             layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0);
+
+            // NextN/MTP tensors (preserved but unused by the trunk), only bound on MTP layers
+            if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+                layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED);
+                layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
+                layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
+                layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+                layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+                layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
+            }
         }
 }
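A worked example of the bookkeeping above, with illustrative sizes: a GGUF storing n_layer = 49 with nextn_predict_layers = 1 has n_main = 48 trunk layers, so the switch resolves to case 48 (LLM_TYPE_122B_A10B); with full_attn_interval = 4, trunk layers 3, 7, ..., 47 are full attention, every other trunk layer is recurrent, and layer 48 (the MTP block) is forced non-recurrent regardless of the interval.

    // Worked example (illustrative sizes): n_layer = 49, nextn_predict_layers = 1.
    const uint32_t n_layer = 49, nextn = 1, full_attn_interval = 4;
    const uint32_t n_main  = n_layer - nextn; // 48 -> case 48: LLM_TYPE_122B_A10B
    for (uint32_t i = 0; i < n_layer; ++i) {
        const bool recurrent = (i < n_main) && ((i + 1) % full_attn_interval != 0);
        // i = 3, 7, ..., 47 -> full attention; i = 48 (MTP) -> non-recurrent as well
        (void) recurrent;
    }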
@@ -124,7 +141,9 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
     ggml_tensor * inp_pos = build_inp_pos();
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
-    for (int il = 0; il < n_layer; ++il) {
+    // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass.
+    const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers;
+    for (int il = 0; il < n_transformer_layers; ++il) {
         ggml_tensor * inpSA = inpL;
 
         cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
@@ -141,7 +160,7 @@
             cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
         }
 
-        if (il == n_layer - 1 && inp_out_ids) {
+        if (il == n_transformer_layers - 1 && inp_out_ids) {
             cur = ggml_get_rows(ctx0, cur, inp_out_ids);
             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
         }
@@ -173,6 +192,9 @@
     }
 
     cur = inpL;
+    cb(cur, "h_pre_norm", -1);
+    res->t_h_pre_norm = cur;
+
     // Final norm
     cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
diff --git a/tests/test-llama-archs.cpp b/tests/test-llama-archs.cpp
index 16af11a2862..fd0d3696d77 100644
--- a/tests/test-llama-archs.cpp
+++ b/tests/test-llama-archs.cpp
@@ -406,6 +406,9 @@ static bool arch_supported(const llm_arch arch) {
     if (arch == LLM_ARCH_DEEPSEEK2OCR) {
         return false;
     }
+    if (arch == LLM_ARCH_QWEN35_MTP || arch == LLM_ARCH_QWEN35MOE_MTP) {
+        return false; // MTP-only arch; requires a sibling trunk model and cannot run standalone.
+    }
 
     // FIXME some models are segfaulting with WebGPU:
 #ifdef GGML_USE_WEBGPU
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index cc9550670a0..f7219e687ff 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -57,6 +57,11 @@ struct server_slot {
     llama_context * ctx_tgt = nullptr;
     llama_context * ctx_dft = nullptr;
 
+    // True when this slot's speculative impl is MTP (ctx_dft is the MTP head).
+    // MTP needs every prefill position to carry logits=1 so the streaming
+    // hook in common_speculative_state_mtp::process() can read t_h_pre_norm.
+    bool is_mtp_enabled = false;
+
     // multimodal
     mtmd_context * mctx = nullptr;
 
@@ -237,8 +242,20 @@ struct server_slot {
             (ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size());
     }
 
+    bool is_mtp() const { return is_mtp_enabled; }
+
+    // The trunk needs to emit logits at every prefill position when either:
+    // - the task asked for embeddings, or
+    // - MTP is enabled for this slot (the streaming hook in process() reads
+    //   h_pre_norm at every prompt position).
+    bool need_embd() const {
+        GGML_ASSERT(task);
+        return task->need_embd() || is_mtp();
+    }
+
     // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
     // also we cannot split if the pooling would require any past tokens
+    // (MTP still allows splitting: this check deliberately uses task->need_embd(), not the MTP-aware need_embd())
     bool can_split() const {
         GGML_ASSERT(task);
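The practical effect of the slot-level need_embd() shows up at prefill time: with MTP enabled every prompt position becomes an output row, so the trunk materializes a pre-norm state that the process() hook can mirror into ctx_dft, whereas a plain slot only flags the last position. An illustration with the server's batching reduced to its essentials (prompt, slot_id and slot_is_mtp are stand-ins, not the server's actual fields):

    // Illustration only: how the logits flag differs per prompt position.
    for (size_t k = 0; k < prompt.size(); ++k) {
        const bool want_out = slot_is_mtp
            ? true                       // MTP: every position outputs (need_embd())
            : k == prompt.size() - 1;    // default: only the last prompt token
        common_batch_add(batch, prompt[k], (llama_pos) k, { slot_id }, want_out);
    }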
@@ -743,6 +760,48 @@ struct server_context_impl {
         ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
 
+        params_base.speculative.draft.ctx_tgt = ctx_tgt;
+        params_base.speculative.draft.ctx_dft = ctx_dft.get();
+    } else if (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) {
+        // MTP head lives in the *target* GGUF: load it as a sibling model
+        // with override_arch and feed it through the existing ctx_dft slot.
+        char trunk_arch[64] = {0};
+        llama_model_meta_val_str(model_tgt, "general.architecture", trunk_arch, sizeof(trunk_arch));
+
+        const char * mtp_arch = nullptr;
+        if (std::string(trunk_arch) == "qwen35") {
+            mtp_arch = "qwen35_mtp";
+        } else if (std::string(trunk_arch) == "qwen35moe") {
+            mtp_arch = "qwen35moe_mtp";
+        } else {
+            SRV_ERR("MTP not supported for trunk architecture '%s'\n", trunk_arch);
+            return false;
+        }
+
+        SRV_INF("loading MTP head from '%s' (override_arch=%s)\n",
+                params_base.model.path.c_str(), mtp_arch);
+
+        auto mparams_mtp = common_model_params_to_llama(params_base);
+        mparams_mtp.override_arch = mtp_arch;
+
+        model_dft.reset(llama_model_load_from_file(params_base.model.path.c_str(), mparams_mtp));
+        if (model_dft == nullptr) {
+            SRV_ERR("failed to load MTP head from '%s'\n", params_base.model.path.c_str());
+            return false;
+        }
+
+        auto cparams_mtp = common_context_params_to_llama(params_base);
+        cparams_mtp.n_ctx = llama_n_ctx_seq(ctx_tgt);
+        cparams_mtp.n_seq_max = 1;
+
+        ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams_mtp));
+        if (ctx_dft == nullptr) {
+            SRV_ERR("%s", "failed to create MTP context\n");
+            return false;
+        }
+
+        ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
+
         params_base.speculative.draft.ctx_tgt = ctx_tgt;
         params_base.speculative.draft.ctx_dft = ctx_dft.get();
     }
@@ -855,6 +914,7 @@
         slot.ctx_tgt = ctx_tgt;
         slot.ctx_dft = ctx_dft.get();
         slot.spec = spec.get();
+        slot.is_mtp_enabled = (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) && (ctx_dft != nullptr);
         slot.n_ctx = n_ctx_slot;
         slot.mctx = mctx;
@@ -2710,12 +2770,14 @@
                 break;
             }
 
-            // embedding requires all tokens in the batch to be output
+            // embedding requires all tokens in the batch to be output;
+            // MTP also wants logits at every prompt position so the
+            // streaming hook can mirror t_h_pre_norm into ctx_dft.
             common_batch_add(batch,
                     cur_tok,
                     slot.prompt.tokens.pos_next(),
                     { slot.id },
-                    slot.task->need_embd());
+                    slot.need_embd());
 
             slot.prompt.tokens.push_back(cur_tok);
             slot.n_prompt_tokens_processed++;
@@ -2832,7 +2894,7 @@
             slot_batched->lora[alora_disabled_id].scale = alora_scale;
         }
 
-        llama_set_embeddings(ctx_tgt, slot_batched->task->need_embd());
+        llama_set_embeddings(ctx_tgt, slot_batched->need_embd());
     }
 
     if (batch.n_tokens == 0) {
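End to end, nothing beyond the new flag should be needed: the head is recovered from the same GGUF via override_arch, so no separate draft file is passed. An illustrative invocation, with the model path hypothetical:

    llama-server -m qwen3.5-moe-instruct.gguf --spec-type mtp

If the GGUF was converted without its NextN tensors, loading the sibling model is expected to fail at the required eh_proj/enorm/hnorm tensors and the SRV_ERR paths above will report it.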