From 9b996f03fe3fb7f3d057071fd6a4981d06dbd0f6 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Mon, 11 May 2026 11:18:17 +0800 Subject: [PATCH 01/28] spec: support MTP --- common/common.h | 1 + common/speculative.cpp | 339 +++++++++++++++++++++++++++++++- conversion/qwen.py | 58 +++++- gguf-py/gguf/constants.py | 18 +- include/llama.h | 3 + src/llama-arch.cpp | 19 +- src/llama-arch.h | 2 + src/llama-context.cpp | 117 +++++++++-- src/llama-context.h | 9 + src/llama-cparams.h | 1 + src/llama-ext.h | 16 ++ src/llama-graph.h | 2 + src/llama-hparams.cpp | 6 + src/llama-hparams.h | 2 + src/llama-model-loader.cpp | 13 +- src/llama-model-loader.h | 2 +- src/llama-model.cpp | 19 +- src/models/models.h | 26 +++ src/models/qwen35.cpp | 32 ++- src/models/qwen35_mtp.cpp | 207 +++++++++++++++++++ src/models/qwen35moe.cpp | 32 ++- src/models/qwen35moe_mtp.cpp | 252 ++++++++++++++++++++++++ tests/test-llama-archs.cpp | 3 + tools/server/server-context.cpp | 73 ++++++- 24 files changed, 1206 insertions(+), 46 deletions(-) create mode 100644 src/models/qwen35_mtp.cpp create mode 100644 src/models/qwen35moe_mtp.cpp diff --git a/common/common.h b/common/common.h index c6223c4b515..37dab50ba5c 100644 --- a/common/common.h +++ b/common/common.h @@ -159,6 +159,7 @@ enum common_speculative_type { COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, // standalone draft model speculative decoding COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, // Eagle3 speculative decoding + COMMON_SPECULATIVE_TYPE_MTP, // multi-token prediction head loaded from the target GGUF COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding based on n-grams COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values diff --git a/common/speculative.cpp b/common/speculative.cpp index 476e1398ed8..b6530253ef5 100644 --- 
a/common/speculative.cpp +++ b/common/speculative.cpp @@ -3,6 +3,7 @@ #include "common.h" #include "ggml.h" #include "llama.h" +#include "../src/llama-ext.h" // staging API: llama_set_embeddings_pre_norm / llama_get_embeddings_pre_norm_ith (used by MTP) #include "log.h" #include "ngram-cache.h" #include "ngram-map.h" @@ -23,6 +24,7 @@ const std::map common_speculative_type_fro {"none", COMMON_SPECULATIVE_TYPE_NONE}, {"draft-simple", COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE}, {"draft-eagle3", COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3}, + {"mtp", COMMON_SPECULATIVE_TYPE_MTP}, {"ngram-simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE}, {"ngram-map-k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K}, {"ngram-map-k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V}, @@ -364,6 +366,330 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl { } }; +struct common_speculative_state_mtp : public common_speculative_impl { + common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft) + + llama_batch batch; + + std::vector smpls; + + int32_t n_embd = 0; + + // Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1. + // The last h-row of one process() call needs the first token of the NEXT + // call to pair with, so it's stashed here until that next call fires. + std::vector> pending_h; // [n_seq][n_embd] + std::vector pending_pos; // [n_seq] + + std::vector last_n_drafted; + std::vector last_n_accepted; + + // Number of trunk output rows produced by the most recent process() call. + // Used by draft() for the first AR step (when last_n_accepted is -1) to + // pick the last prefill row out of ctx_tgt's pre-norm buffer. 
+ std::vector last_trunk_n_outputs; + + common_speculative_state_mtp(const common_params_speculative & params, uint32_t n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_MTP, n_seq) + , params(params.draft) + { + GGML_ASSERT(n_seq == 1 && "MTP currently supports only single-sequence speculation"); + + auto * ctx_tgt = this->params.ctx_tgt; + auto * ctx_dft = this->params.ctx_dft; + GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set"); + + n_embd = llama_model_n_embd(llama_get_model(ctx_dft)); + + const int32_t n_ub = (int32_t) llama_n_ubatch(ctx_dft); + batch = llama_batch_init(/*n_tokens=*/ n_ub, /*embd=*/ n_embd, /*n_seq_max=*/ 1); + // llama_batch_init allocates only one of token/embd; MTP needs both. + // TODO: fix, how to call without malloc + batch.token = (llama_token *) malloc(sizeof(llama_token) * n_ub); + + smpls.resize(n_seq); + for (auto & s : smpls) { + common_params_sampling sparams; + sparams.no_perf = false; + sparams.top_k = 1; + sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K }; + s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams)); + } + + llama_set_embeddings_pre_norm(ctx_tgt, true); + llama_set_embeddings_pre_norm(ctx_dft, true); + + pending_h.assign(n_seq, std::vector(n_embd, 0.0f)); + pending_pos.assign(n_seq, -1); + + last_n_drafted.assign(n_seq, 0); + last_n_accepted.assign(n_seq, -1); + last_trunk_n_outputs.assign(n_seq, 0); + } + + ~common_speculative_state_mtp() override { + if (batch.token != nullptr) { + free(batch.token); + batch.token = nullptr; + } + llama_batch_free(batch); + } + + void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < pending_pos.size()); + + last_n_accepted[seq_id] = -1; + last_n_drafted [seq_id] = 0; + pending_pos [seq_id] = -1; + + const int32_t N = (int32_t) prompt.size(); + if (N <= 0) { + return; + } + auto * ctx_dft = this->params.ctx_dft; + const llama_pos pos_max = 
llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id); + if (pos_max < N - 1) { + LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d — " + "process() hook may not have run on every prefill ubatch " + "(need_embd / logits=1 on every prompt position?). " + "Drafts may degrade.\n", + __func__, (int) pos_max, N - 1); + } + } + + bool process(const llama_batch & batch_in) override { + if (batch_in.n_tokens <= 0) { + return true; + } + + // Single-seq for now (asserted in ctor). Future: bucket by seq_id. + const llama_seq_id seq_id = 0; + + // TODO: how to make it work with vision tokens? + if (batch_in.token == nullptr || batch_in.embd != nullptr) { + pending_pos[seq_id] = -1; + return true; + } + + auto * ctx_tgt = this->params.ctx_tgt; + auto * ctx_dft = this->params.ctx_dft; + + const int32_t n_rows = batch_in.n_tokens; + const llama_pos pos_start = batch_in.pos[0]; + + const llama_pos pos_max_dft = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id); + if (pos_start <= pos_max_dft) { + return true; + } + + // Stale pending: discard if the new batch doesn't start one past it. + const bool pending_continues = + pending_pos[seq_id] >= 0 && pending_pos[seq_id] + 1 == pos_start; + if (pending_pos[seq_id] >= 0 && !pending_continues) { + pending_pos[seq_id] = -1; + } + + // Build a paired hook batch: + // row 0 = (pending_h, batch_in.token[0]) at pos_start if pending_continues + // rows 1..n_rows-1 = (h_k from this batch, batch_in.token[k+1]) at pos[k+1] + // The last h-row (h_{n_rows-1}) is stashed as the new pending and is *not* + // decoded this call — it waits for the next batch's first token to pair. 
+ const size_t row_bytes = (size_t) n_embd * sizeof(float); + + common_batch_clear(batch); + int out_idx = 0; + + auto add_pair = [&](const float * h_row, llama_token tok, llama_pos pos) { + std::memcpy(batch.embd + (size_t) out_idx * n_embd, h_row, row_bytes); + batch.token [out_idx] = tok; + batch.pos [out_idx] = pos; + batch.n_seq_id[out_idx] = 1; + batch.seq_id [out_idx][0] = seq_id; + batch.logits [out_idx] = 0; + ++out_idx; + }; + + if (pending_continues) { + add_pair(pending_h[seq_id].data(), batch_in.token[0], pos_start); + } + + // TODO: is there a faster way to build this batch? + for (int k = 0; k + 1 < n_rows; ++k) { + if (batch_in.logits[k] == 0) { + LOG_WRN("%s: batch_in.logits[%d] == 0 (need_embd / logits=1 missing on prefill); stopping hook at this row\n", + __func__, k); + break; + } + const float * h_k = llama_get_embeddings_pre_norm_ith(ctx_tgt, k); + if (h_k == nullptr) { + LOG_WRN("%s: ctx_tgt has no pre-norm row at i=%d; stopping hook\n", __func__, k); + break; + } + add_pair(h_k, batch_in.token[k + 1], batch_in.pos[k + 1]); + } + + if (out_idx > 0) { + batch.n_tokens = out_idx; + const int32_t rc = llama_decode(ctx_dft, batch); + if (rc != 0) { + LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d, n=%d)\n", + __func__, (int) rc, (int) pos_start, out_idx); + return false; + } + } + + // Record the row count so draft()'s first AR step (when last_n_accepted < 0) can find the last pre-norm row of this batch. + // We assume every batch position has logits=1 (server sets need_embd + // for MTP slots) → n_outputs == n_tokens. + last_trunk_n_outputs[seq_id] = n_rows; + + // Stash the last h-row (h_{n_rows-1}) as the new pending for the next + // process() call's first token to pair with.
+ if (batch_in.logits[n_rows - 1] != 0) { + const float * h_last = llama_get_embeddings_pre_norm_ith(ctx_tgt, n_rows - 1); + if (h_last != nullptr) { + std::memcpy(pending_h[seq_id].data(), h_last, row_bytes); + pending_pos[seq_id] = batch_in.pos[n_rows - 1]; + } else { + pending_pos[seq_id] = -1; + } + } else { + // No trunk output at the tail — can't carry over. + pending_pos[seq_id] = -1; + } + + return true; + } + + void draft(common_speculative_draft_params_vec & dparams) override { + // Single-seq for now (asserted in ctor). Future: iterate over dparams. + const llama_seq_id seq_id = 0; + if ((size_t) seq_id >= dparams.size()) { + return; + } + auto & dp = dparams[seq_id]; + if (!dp.drafting) { + return; + } + + auto * ctx_tgt = this->params.ctx_tgt; + auto * ctx_dft = this->params.ctx_dft; + auto * smpl = smpls[seq_id].get(); + + GGML_ASSERT(dp.result != nullptr); + auto & draft_tokens = *dp.result; + draft_tokens.clear(); + + if (last_n_drafted[seq_id] > 0) { + const int32_t n_to_drop = (int32_t) last_n_drafted[seq_id] - 1; + if (n_to_drop > 0) { + const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id); + if (pos_max >= 0) { + const llama_pos drop_from = pos_max - n_to_drop + 1; + llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1); + } + } + last_n_drafted[seq_id] = 0; + last_n_accepted[seq_id] = 0; + } + + // Effective draft length: min(global cap, per-sequence override). + int32_t n_max = std::max(1, params.n_max); + if (dp.n_max > 0) { + n_max = std::min(n_max, dp.n_max); + } + + const size_t row_bytes = (size_t) n_embd * sizeof(float); + + common_sampler_reset(smpl); + + llama_token cond_tok = dp.id_last; + llama_pos pos = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id) + 1; + + for (int32_t k = 0; k < n_max; ++k) { + const float * h_row = nullptr; + + if (k == 0) { + // Condition on the trunk's pre-norm row. 
+ int32_t row_idx; + if (last_n_accepted[seq_id] < 0) { + // First draft after begin(): use the last prefill row. + row_idx = last_trunk_n_outputs[seq_id] - 1; + } else { + // After accept(n_accepted): row of the next conditioning + // position in the trunk's verify batch. + row_idx = last_n_accepted[seq_id]; + } + if (row_idx < 0) { + LOG_WRN("%s: no trunk pre-norm row available (row_idx=%d); stopping chain\n", + __func__, row_idx); + break; + } + h_row = llama_get_embeddings_pre_norm_ith(ctx_tgt, row_idx); + } else { + // AR step: condition on the MTP head's own pre-norm row from + // the just-completed single-token decode. n_outputs=1 there, + // so the row is at batch position 0. + h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, 0); + } + + if (h_row == nullptr) { + LOG_WRN("%s: missing pre-norm row at k=%d; stopping chain\n", __func__, k); + break; + } + + // 1-token batch carrying both (token, h_pre_norm). + common_batch_clear(batch); + std::memcpy(batch.embd, h_row, row_bytes); + batch.token [0] = cond_tok; + batch.pos [0] = pos; + batch.n_seq_id[0] = 1; + batch.seq_id [0][0] = seq_id; + batch.logits [0] = 1; // need logits for sampling + batch.n_tokens = 1; + + const int32_t rc = llama_decode(ctx_dft, batch); + if (rc != 0) { + LOG_WRN("%s: llama_decode(ctx_dft) failed rc=%d at k=%d; stopping chain\n", + __func__, rc, k); + break; + } + + const llama_token best = common_sampler_sample(smpl, ctx_dft, 0); + common_sampler_accept(smpl, best, /*is_generated=*/ false); + draft_tokens.push_back(best); + cond_tok = best; + ++pos; + } + + last_n_drafted[seq_id] = (uint16_t) draft_tokens.size(); + } + + void accept(llama_seq_id seq_id, uint16_t n_accepted) override { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < last_n_drafted.size()); + + auto * ctx_dft = this->params.ctx_dft; + + const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id); + const int32_t n_drafted_last = (int32_t) last_n_drafted[seq_id]; + + const int32_t n_to_drop = 
std::max(0, n_drafted_last - (int32_t) n_accepted - 1); + + if (pos_max < 0) { + last_n_accepted[seq_id] = (int32_t) n_accepted; + return; + } + + if (n_to_drop > 0) { + const llama_pos drop_from = pos_max - n_to_drop + 1; + llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1); + } + + last_n_drafted [seq_id] = 0; + last_n_accepted[seq_id] = (int32_t) n_accepted; + } +}; + // state of self-speculation (simple implementation, not ngram-map) struct common_speculative_impl_ngram_simple : public common_speculative_impl { common_params_speculative_ngram_map params; @@ -820,6 +1146,7 @@ std::string common_speculative_type_to_str(common_speculative_type type) { case COMMON_SPECULATIVE_TYPE_NONE: return "none"; case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: return "draft-simple"; case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: return "draft-eagle3"; + case COMMON_SPECULATIVE_TYPE_MTP: return "mtp"; case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram-simple"; case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram-map-k"; case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram-map-k4v"; @@ -875,8 +1202,8 @@ common_speculative * common_speculative_init(common_params_speculative & params, bool has_draft_model_path = !params.draft.mparams.path.empty(); bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE)); - // bool has_mtp = false; // TODO: add MTP here bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3 + bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_MTP)) && params.draft.ctx_dft != nullptr; bool has_ngram_cache = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_CACHE)); bool has_ngram_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE)); @@ -885,7 +1212,7 @@ common_speculative * common_speculative_init(common_params_speculative & params, bool has_ngram_mod = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MOD)); // when adding a new type - 
update here the logic above - static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 8); + static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 9); // this list here defines the priority of the speculators // the one with highest priority are listed first @@ -919,10 +1246,12 @@ common_speculative * common_speculative_init(common_params_speculative & params, if (has_draft_simple) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, params)); } - // TODO: add MTP here if (has_draft_eagle3) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params)); } + if (has_mtp) { + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params)); + } } std::vector> impls = {}; @@ -940,6 +1269,10 @@ common_speculative * common_speculative_init(common_params_speculative & params, impls.push_back(std::make_unique(config.params, n_seq)); break; } + case COMMON_SPECULATIVE_TYPE_MTP: { + impls.push_back(std::make_unique(config.params, n_seq)); + break; + } case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple); diff --git a/conversion/qwen.py b/conversion/qwen.py index 919ecddcb91..e8b49a22bdb 100644 --- a/conversion/qwen.py +++ b/conversion/qwen.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Callable, Iterable, TYPE_CHECKING +from pathlib import Path +from typing import Any, Callable, Iterable, TYPE_CHECKING import torch @@ -534,11 +535,62 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_dimension_sections(self._QWEN35_DEFAULT_MROPE_SECTION) +class _Qwen35MtpMixin: + """Shared MTP wiring for Qwen3.5/3.6 text variants. 
The HF config carries + the MTP block under `mtp_num_hidden_layers` and the tensors under + `mtp.*`; we extend block_count, emit the nextn metadata key, and remap + `mtp.*` to the standard layer-indexed nextn naming so the existing + tensor_map handles them.""" + + hparams: dict[str, Any] + model_arch: gguf.MODEL_ARCH + gguf_writer: gguf.GGUFWriter + block_count: int + tensor_map: gguf.TensorNameMap + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("mtp_num_hidden_layers", 0) + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + def set_gguf_parameters(self): + super().set_gguf_parameters() # ty: ignore[unresolved-attribute] + if (n := self.hparams.get("mtp_num_hidden_layers", 0)) > 0: + self.gguf_writer.add_nextn_predict_layers(n) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.startswith("model.language_model."): + name = "model." 
+ name[len("model.language_model."):] + elif name.startswith("language_model."): + name = name[len("language_model."):] + + if name.startswith("mtp."): + n_layer = self.hparams["num_hidden_layers"] + if name.find("layers.") != -1: + assert bid is not None + name = name.replace(f"mtp.layers.{bid}", f"model.layers.{bid + n_layer}") + else: + remapper = { + "mtp.fc": "model.layers.{bid}.eh_proj", + "mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm", + "mtp.pre_fc_norm_hidden": "model.layers.{bid}.hnorm", + "mtp.norm": "model.layers.{bid}.shared_head.norm", + } + stem = Path(name).stem + suffix = Path(name).suffix + tmpl = remapper[stem] + suffix + for b in range(n_layer, self.block_count): + yield from super().modify_tensors(data_torch, tmpl.format(bid=b), b) # ty: ignore[unresolved-attribute] + return + + yield from super().modify_tensors(data_torch, name, bid) # ty: ignore[unresolved-attribute] + + @ModelBase.register("Qwen3_5ForConditionalGeneration", "Qwen3_5ForCausalLM") -class Qwen3_5TextModel(_Qwen35MRopeMixin, _LinearAttentionVReorderBase): +class Qwen3_5TextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase): model_arch = gguf.MODEL_ARCH.QWEN35 @ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM") -class Qwen3_5MoeTextModel(_Qwen35MRopeMixin, _LinearAttentionVReorderBase): +class Qwen3_5MoeTextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase): model_arch = gguf.MODEL_ARCH.QWEN35MOE diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 4055ec2873a..c25f217f990 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -2114,7 +2114,14 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_NORM, MODEL_TENSOR.SSM_BETA, MODEL_TENSOR.SSM_ALPHA, - MODEL_TENSOR.SSM_OUT + MODEL_TENSOR.SSM_OUT, + # NextN/MTP tensors - preserved but unused + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + 
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, ], MODEL_ARCH.QWEN35MOE: [ MODEL_TENSOR.TOKEN_EMBD, @@ -2145,7 +2152,14 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_NORM, MODEL_TENSOR.SSM_BETA, MODEL_TENSOR.SSM_ALPHA, - MODEL_TENSOR.SSM_OUT + MODEL_TENSOR.SSM_OUT, + # NextN/MTP tensors - preserved but unused + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, ], MODEL_ARCH.PLAMO: [ MODEL_TENSOR.TOKEN_EMBD, diff --git a/include/llama.h b/include/llama.h index 308e8ba9dbd..1b896944735 100644 --- a/include/llama.h +++ b/include/llama.h @@ -310,6 +310,9 @@ extern "C" { // override key-value pairs of the model meta data const struct llama_model_kv_override * kv_overrides; + // override architecture from GGUF (e.g. load the MTP head of a Qwen3.5 GGUF as "qwen35_mtp") + const char * override_arch; + // Keep the booleans together to avoid misalignment during copy-by-value. 
bool vocab_only; // only load the vocabulary, no weights bool use_mmap; // use mmap if possible diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 59dde99e362..794666d09a4 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -41,6 +41,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" }, { LLM_ARCH_QWEN35, "qwen35" }, { LLM_ARCH_QWEN35MOE, "qwen35moe" }, + { LLM_ARCH_QWEN35_MTP, "qwen35_mtp" }, + { LLM_ARCH_QWEN35MOE_MTP, "qwen35moe_mtp" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHIMOE, "phimoe" }, @@ -757,14 +759,15 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - // NextN/MTP tensors are currently ignored (reserved for future MTP support) - // These tensors only exist in the last layer(s) and are treated as output tensors - {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + // NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the + // last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so + // the model loader doesn't fault on the block index. 
+ {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // Nemotron 3 Super {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index e37d548c98e..71c2ca6e6b3 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -45,6 +45,8 @@ enum llm_arch { LLM_ARCH_QWEN3VLMOE, LLM_ARCH_QWEN35, LLM_ARCH_QWEN35MOE, + LLM_ARCH_QWEN35_MTP, + LLM_ARCH_QWEN35MOE_MTP, LLM_ARCH_PHI2, LLM_ARCH_PHI3, LLM_ARCH_PHIMOE, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 3d9714ab166..aea8a0a4e81 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -49,6 +49,7 @@ llama_context::llama_context( cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? 
params.yarn_beta_slow : hparams.yarn_beta_slow; cparams.embeddings = params.embeddings; + cparams.embeddings_pre_norm = false; cparams.offload_kqv = params.offload_kqv; cparams.no_perf = params.no_perf; cparams.pooling_type = params.pooling_type; @@ -860,6 +861,33 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } +float * llama_context::get_embeddings_pre_norm() { + output_reorder(); + + return embd_pre_norm.data; +} + +float * llama_context::get_embeddings_pre_norm_ith(int32_t i) { + output_reorder(); + + try { + if (embd_pre_norm.data == nullptr) { + throw std::runtime_error("no pre-norm embeddings"); + } + + const int64_t j = output_resolve_row(i); + const uint32_t n_embd = model.hparams.n_embd; + return embd_pre_norm.data + j*n_embd; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what()); +#ifndef NDEBUG + GGML_ABORT("fatal error"); +#else + return nullptr; +#endif + } +} + llama_token llama_context::get_sampled_token_ith(int32_t idx) { output_reorder(); @@ -1040,6 +1068,12 @@ void llama_context::set_embeddings(bool value) { //sched_need_reserve = true; } +void llama_context::set_embeddings_pre_norm(bool value) { + LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + + cparams.embeddings_pre_norm = value; +} + void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); @@ -1241,7 +1275,9 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } int llama_context::encode(const llama_batch & batch_inp) { - GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT + // MTP hook batches carry both token (next-token id) and embd (h_pre_norm row), + // so accept either present rather than requiring exactly one. 
+ GGML_ASSERT(batch_inp.token || batch_inp.embd); if (batch_inp.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); @@ -1312,8 +1348,9 @@ int llama_context::encode(const llama_batch & batch_inp) { } } - auto * t_logits = res->get_logits(); - auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + auto * t_logits = res->get_logits(); + auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr; // extract logits if (logits.data && t_logits) { @@ -1379,6 +1416,16 @@ int llama_context::encode(const llama_batch & batch_inp) { } } + // extract pre-norm embeddings (hidden state before the final output norm) + if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + GGML_ASSERT(backend_h != nullptr); + + const uint32_t n_embd = hparams.n_embd; + GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size); + ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float)); + } + // TODO: hacky solution if (model.arch == LLM_ARCH_T5 && t_embd) { //cross.t_embd = t_embd; @@ -1531,7 +1578,9 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::mapget_logits(); - auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + auto * t_logits = res->get_logits(); + auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + auto * t_h_pre_norm = cparams.embeddings_pre_norm ? 
res->get_h_pre_norm() : nullptr; if (t_embd && res->get_embd_pooled()) { t_embd = res->get_embd_pooled(); @@ -1809,6 +1859,20 @@ int llama_context::decode(const llama_batch & batch_inp) { } } + // extract pre-norm embeddings (hidden state before the final output norm) + // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored. + if (embd_pre_norm.data && t_h_pre_norm && n_outputs > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + GGML_ASSERT(backend_h != nullptr); + + const uint32_t n_embd = hparams.n_embd; + float * embd_pre_norm_out = embd_pre_norm.data + n_outputs_prev*n_embd; + + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_pre_norm.size); + ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_outputs*n_embd*sizeof(float)); + } + // Copy backend sampling output if this ubatch produced any sampling tensors. if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) { const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); @@ -1893,10 +1957,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const auto n_batch = cparams.n_batch; const auto n_vocab = vocab.n_tokens(); + const auto n_embd = hparams.n_embd; const auto n_embd_out = hparams.n_embd_out(); - bool has_logits = true; - bool has_embd = cparams.embeddings; + bool has_logits = true; + bool has_embd = cparams.embeddings; + bool has_embd_pre_norm = cparams.embeddings_pre_norm; // TODO: hacky enc-dec support if (model.arch == LLM_ARCH_T5) { @@ -1908,8 +1974,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { size_t backend_float_count = 0; size_t backend_token_count = 0; - logits.size = has_logits ? n_vocab*n_outputs_max : 0; - embd.size = has_embd ? 
n_embd_out*n_outputs_max : 0; + logits.size = has_logits ? n_vocab*n_outputs_max : 0; + embd.size = has_embd ? n_embd_out*n_outputs_max : 0; + embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0; // Allocate backend sampling output buffers if there are backend samplers configured. const bool has_sampling = !sampling.samplers.empty(); @@ -1925,8 +1992,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; const size_t new_size = - (logits.size + embd.size + backend_float_count) * sizeof(float) + - ( backend_token_count) * sizeof(llama_token); + (logits.size + embd.size + embd_pre_norm.size + backend_float_count) * sizeof(float) + + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required // TODO: also consider shrinking the buffer @@ -1942,6 +2009,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { buf_output = nullptr; logits.data = nullptr; embd.data = nullptr; + embd_pre_norm.data = nullptr; } auto * buft = ggml_backend_cpu_buffer_type(); @@ -1970,6 +2038,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { embd = has_embd ? buffer_view{(float *) (base + offset), embd.size} : buffer_view{nullptr, 0}; offset += embd.size * sizeof(float); + embd_pre_norm = has_embd_pre_norm ? 
buffer_view{(float *) (base + offset), embd_pre_norm.size} : buffer_view{nullptr, 0}; + offset += embd_pre_norm.size * sizeof(float); + if (has_sampling) { sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; offset += sampling.logits.size * sizeof(float); @@ -2034,6 +2105,12 @@ void llama_context::output_reorder() { } } + if (embd_pre_norm.size > 0) { + for (uint64_t k = 0; k < n_embd; k++) { + std::swap(embd_pre_norm.data[i0*n_embd + k], embd_pre_norm.data[i1*n_embd + k]); + } + } + if (!sampling.samplers.empty()) { assert(sampling.logits.size > 0); assert(sampling.probs.size > 0); @@ -3436,6 +3513,22 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } +void llama_set_embeddings_pre_norm(llama_context * ctx, bool value) { + ctx->set_embeddings_pre_norm(value); +} + +float * llama_get_embeddings_pre_norm(llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_embeddings_pre_norm(); +} + +float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_embeddings_pre_norm_ith(i); +} + bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { return ctx->set_sampler(seq_id, smpl); } diff --git a/src/llama-context.h b/src/llama-context.h index 92d1b0cf95a..e16ac4c618b 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -84,6 +84,9 @@ struct llama_context { float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); + float * get_embeddings_pre_norm(); + float * get_embeddings_pre_norm_ith(int32_t i); + llama_token * get_sampled_tokens() const; llama_token get_sampled_token_ith(int32_t idx); @@ -107,6 +110,7 @@ struct llama_context { void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data); void set_embeddings (bool value); + void set_embeddings_pre_norm(bool value); void set_causal_attn(bool value); void 
set_warmup(bool value); @@ -278,6 +282,11 @@ struct llama_context { // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE buffer_view embd = {nullptr, 0}; + // hidden state before the final output norm (2-dimensional array: [n_outputs][n_embd]) + // populated only when cparams.embeddings_pre_norm is enabled and the model graph + // sets llm_graph_result::t_h_pre_norm + buffer_view embd_pre_norm = {nullptr, 0}; + struct sampling_info { // !samplers.empty() to check if any samplers are active std::map samplers; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 9d359474132..1e4e9e29ed8 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -27,6 +27,7 @@ struct llama_cparams { float yarn_beta_slow; bool embeddings; + bool embeddings_pre_norm; // also extract the hidden state before the final output norm bool causal_attn; bool offload_kqv; bool flash_attn; diff --git a/src/llama-ext.h b/src/llama-ext.h index 8ce29d217cb..11f1986676a 100644 --- a/src/llama-ext.h +++ b/src/llama-ext.h @@ -88,3 +88,19 @@ LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model); LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i); LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx); + +// +// pre-norm embeddings (hidden state before the final output norm) +// + +// mirrors: +// LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings); +LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value); + +// mirrors: +// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); +LLAMA_API float * llama_get_embeddings_pre_norm(struct llama_context * ctx); + +// mirrors: +// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); +LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i); diff --git a/src/llama-graph.h b/src/llama-graph.h index 
5cb1756c6a9..d3cd69a674c 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -644,6 +644,7 @@ class llm_graph_result { ggml_tensor * get_logits() const { return t_logits; } ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } + ggml_tensor * get_h_pre_norm() const { return t_h_pre_norm; } ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -672,6 +673,7 @@ class llm_graph_result { ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; + ggml_tensor * t_h_pre_norm = nullptr; // [n_embd, n_outputs] hidden state before final output norm std::map t_sampled_logits; std::map t_candidates; diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 002d15d415f..2239309c8fb 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -229,6 +229,12 @@ uint32_t llama_hparams::n_embd_head_v_mla() const { } bool llama_hparams::has_kv(uint32_t il) const { + if (kv_only_nextn) { + // MTP head: only the trailing nextn_predict_layers blocks own a KV cache; + // the leading trunk blocks are not executed in this graph. 
+ return nextn_predict_layers > 0 && il >= (n_layer - nextn_predict_layers); + } + if (n_layer_kv_from_start >= 0) { if (il < (uint32_t) n_layer_kv_from_start) { return true; diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 0160a89caa2..e2d051edc6c 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -92,6 +92,8 @@ struct llama_hparams { uint32_t moe_latent_size = 0; uint32_t nextn_predict_layers = 0; + bool kv_only_nextn = false; // if true, only the last nextn_predict_layers blocks have a KV cache (MTP head arches) + float f_norm_eps; float f_norm_rms_eps; float f_norm_group_eps; diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 4e65a45a50d..c645d0785ab 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -1312,9 +1312,16 @@ struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_conte return tensor; } -void llama_model_loader::done_getting_tensors() const { - if (n_created != n_tensors) { - throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); +void llama_model_loader::done_getting_tensors(bool partial) const { + if (n_created > n_tensors) { + throw std::runtime_error(format("%s: too many tensors created; expected %d, got %d", __func__, n_tensors, n_created)); + } + if (n_created < n_tensors) { + if (!partial) { + throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); + } + LLAMA_LOG_INFO("%s: partial load — used %d of %d tensors in the file (rest belong to a sibling model on the same .gguf)\n", + __func__, n_created, n_tensors); } if (n_tensors_moved > 0) { LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %zu others) cannot be used with preferred buffer type %s, using %s instead\n", diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h index 7b3d6703c03..c476026d3e5 100644 --- a/src/llama-model-loader.h +++ b/src/llama-model-loader.h @@ -184,7 
+184,7 @@ struct llama_model_loader { struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list & ne, size_t offset, bool required = true); - void done_getting_tensors() const; + void done_getting_tensors(bool partial = false) const; void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 46ae010f800..e14f375521e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -276,6 +276,10 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_qwen35(params); case LLM_ARCH_QWEN35MOE: return new llama_model_qwen35moe(params); + case LLM_ARCH_QWEN35_MTP: + return new llama_model_qwen35_mtp(params); + case LLM_ARCH_QWEN35MOE_MTP: + return new llama_model_qwen35moe_mtp(params); case LLM_ARCH_MISTRAL3: return new llama_model_mistral3(params); case LLM_ARCH_MIMO2: @@ -309,6 +313,15 @@ llama_model * llama_model_create(llama_model_loader & ml, const llama_model_para if (arch == LLM_ARCH_UNKNOWN) { throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); } + if (params.override_arch != nullptr && params.override_arch[0] != '\0') { + const llm_arch override = llm_arch_from_string(params.override_arch); + if (override == LLM_ARCH_UNKNOWN) { + throw std::runtime_error(std::string("unknown override architecture: '") + params.override_arch + "'"); + } + LLAMA_LOG_INFO("%s: overriding architecture %s -> %s\n", + __func__, llm_arch_name(arch), llm_arch_name(override)); + arch = override; + } return llama_model_create(arch, params); } @@ -1406,7 +1419,8 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { } } } - ml.done_getting_tensors(); + const bool partial_load = (arch == LLM_ARCH_QWEN35_MTP || arch == LLM_ARCH_QWEN35MOE_MTP); + ml.done_getting_tensors(partial_load); GGML_ASSERT(!(output && tok_embd && 
strcmp(output->name, tok_embd->name) == 0 && @@ -2102,6 +2116,7 @@ llama_model_params llama_model_default_params() { /*.progress_callback =*/ nullptr, /*.progress_callback_user_data =*/ nullptr, /*.kv_overrides =*/ nullptr, + /*.override_arch =*/ nullptr, /*.vocab_only =*/ false, /*.use_mmap =*/ true, /*.use_direct_io =*/ false, @@ -2326,6 +2341,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_QWEN3VLMOE: case LLM_ARCH_QWEN35: case LLM_ARCH_QWEN35MOE: + case LLM_ARCH_QWEN35_MTP: + case LLM_ARCH_QWEN35MOE_MTP: return LLAMA_ROPE_TYPE_IMROPE; case LLM_ARCH_GLM4: diff --git a/src/models/models.h b/src/models/models.h index 6d5f18a8e20..1f04d313d13 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -1785,6 +1785,32 @@ struct llama_model_qwen35moe : public llama_model_base { }; +struct llama_model_qwen35_mtp : public llama_model_base { + llama_model_qwen35_mtp(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_qwen35moe_mtp : public llama_model_base { + llama_model_qwen35moe_mtp(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + struct llama_model_mistral3 : public llama_model_base { llama_model_mistral3(const struct llama_model_params & params) : llama_model_base(params) {} void 
load_arch_hparams(llama_model_loader & ml) override; diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index b188810f931..1b7796f775e 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -12,16 +12,23 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - // Mark recurrent layers (linear attention layers) + // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + // Mark recurrent layers (linear attention layers). MTP layers are dense + // attention-only and must be flagged non-recurrent. { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer) { + switch (hparams.n_layer - hparams.nextn_predict_layers) { case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break; case 32: type = hparams.n_embd == 2560 ? 
LLM_TYPE_4B : LLM_TYPE_9B; break; case 64: type = LLM_TYPE_27B; break; @@ -83,6 +90,16 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader &) { layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // NextN/MTP tensors (preserved but unused) - only bound on MTP layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + } } } @@ -111,7 +128,9 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. 
+ const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -128,7 +147,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -160,6 +179,9 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para } cur = inpL; + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); diff --git a/src/models/qwen35_mtp.cpp b/src/models/qwen35_mtp.cpp new file mode 100644 index 00000000000..83039e98db5 --- /dev/null +++ b/src/models/qwen35_mtp.cpp @@ -0,0 +1,207 @@ +#include "models.h" + +void llama_model_qwen35_mtp::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35_MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers <= hparams.n_layer); + + // only the MTP layers get a KV cache, trunk layers are skipped. 
+ hparams.kv_only_nextn = true; + hparams.n_layer_kv_from_start = -1; + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = false; + } + + type = LLM_TYPE_UNKNOWN; +} + +void llama_model_qwen35_mtp::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, TENSOR_NOT_REQUIRED); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + if (output == nullptr) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + for (int i = 0; i < n_layer; ++i) { + if (static_cast(i) < n_main) { + continue; // trunk layer — owned by the sibling QWEN35 model + } + + auto & layer = layers[i]; + + // MTP block looks like a full-attention Qwen3.5 decoder block. + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // NextN-specific tensors that define the MTP block. 
+ layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr llama_model_qwen35_mtp::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +// LLM_ARCH_QWEN35_MTP draft head for Qwen3.5/3.6 dense series +llama_model_qwen35_mtp::graph::graph(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35_MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35_MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + // The MTP block lives at the source file's original layer index. 
+ const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * 
n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 
1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + cur = build_ffn(cur, + layer.ffn_up, nullptr, layer.ffn_up_s, + layer.ffn_gate, nullptr, layer.ffn_gate_s, + layer.ffn_down, nullptr, layer.ffn_down_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + // (In the trunk graph this is `t_h_pre_norm`; the MTP head reuses the same slot.) + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35_MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? 
layer.nextn.shared_head_head : model.output; + GGML_ASSERT(head_w && "QWEN35_MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index 8ec9b8c6f7d..43d9c7a1e3c 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -15,16 +15,23 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - // Mark recurrent layers (linear attention layers) + // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + // Mark recurrent layers (linear attention layers). MTP layers are dense + // attention-only and must be flagged non-recurrent. 
{ + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer) { + switch (hparams.n_layer - hparams.nextn_predict_layers) { case 40: type = LLM_TYPE_35B_A3B; break; case 48: type = LLM_TYPE_122B_A10B; break; case 60: type = LLM_TYPE_397B_A17B; break; @@ -96,6 +103,16 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) { layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + + // NextN/MTP tensors (preserved but unused) - only bound on MTP layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + } } } @@ -124,7 +141,9 @@ 
llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. + const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -141,7 +160,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -173,6 +192,9 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p } cur = inpL; + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); diff --git a/src/models/qwen35moe_mtp.cpp b/src/models/qwen35moe_mtp.cpp new file mode 100644 index 00000000000..9f662213bee --- /dev/null +++ b/src/models/qwen35moe_mtp.cpp @@ -0,0 +1,252 @@ +#include "models.h" + +void llama_model_qwen35moe_mtp::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE_MTP requires nextn_predict_layers > 0"); + 
GGML_ASSERT(hparams.nextn_predict_layers <= hparams.n_layer); + GGML_ASSERT(hparams.n_expert > 0 && "QWEN35MOE_MTP requires n_expert > 0"); + + // only the MTP layers get a KV cache, trunk layers are skipped. + hparams.kv_only_nextn = true; + hparams.n_layer_kv_from_start = -1; + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = false; + } + + type = LLM_TYPE_UNKNOWN; +} + +void llama_model_qwen35moe_mtp::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, TENSOR_NOT_REQUIRED); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + if (output == nullptr) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; + + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + for (int i = 0; i < n_layer; ++i) { + if (static_cast(i) < n_main) { + continue; // trunk layer — owned by the sibling QWEN35MOE model + } + + auto & layer = layers[i]; + + // MTP block looks like a full-attention Qwen3.5 decoder block with MoE FFN. 
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + + // Routed experts + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + + // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + + // NextN-specific tensors that define the MTP block. 
+ layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr llama_model_qwen35moe_mtp::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +// LLM_ARCH_QWEN35MOE_MTP draft head for Qwen3.5/3.6 MoE +llama_model_qwen35moe_mtp::graph::graph(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE_MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35MOE_MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + GGML_ASSERT(layer.ffn_gate_inp && "MTP block missing ffn_gate_inp"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, 
GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); 
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + // MoE FFN — routed experts plus gated shared expert (mirrors qwen35moe). 
+ ggml_tensor * moe_out = + build_moe_ffn(cur, + layer.ffn_gate_inp, + layer.ffn_up_exps, + layer.ffn_gate_exps, + layer.ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, + nullptr, layer.ffn_gate_up_exps, + layer.ffn_up_exps_s, + layer.ffn_gate_exps_s, + layer.ffn_down_exps_s); + cb(moe_out, "mtp_ffn_moe_out", il); + + if (layer.ffn_up_shexp != nullptr) { + ggml_tensor * ffn_shexp = + build_ffn(cur, + layer.ffn_up_shexp, nullptr, layer.ffn_up_shexp_s, + layer.ffn_gate_shexp, nullptr, layer.ffn_gate_shexp_s, + layer.ffn_down_shexp, nullptr, layer.ffn_down_shexp_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "mtp_ffn_shexp", il); + + ggml_tensor * shared_gate = build_lora_mm(layer.ffn_gate_inp_shexp, cur); + shared_gate = ggml_sigmoid(ctx0, shared_gate); + cb(shared_gate, "mtp_shared_expert_gate_sigmoid", il); + + ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); + cb(ffn_shexp, "mtp_ffn_shexp_gated", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + } else { + cur = moe_out; + } + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35MOE_MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? 
layer.nextn.shared_head_head : model.output; + GGML_ASSERT(head_w && "QWEN35MOE_MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/tests/test-llama-archs.cpp b/tests/test-llama-archs.cpp index 16af11a2862..fd0d3696d77 100644 --- a/tests/test-llama-archs.cpp +++ b/tests/test-llama-archs.cpp @@ -406,6 +406,9 @@ static bool arch_supported(const llm_arch arch) { if (arch == LLM_ARCH_DEEPSEEK2OCR) { return false; } + if (arch == LLM_ARCH_QWEN35_MTP || arch == LLM_ARCH_QWEN35MOE_MTP) { + return false; // MTP-only arch; requires a sibling trunk model and cannot run standalone. + } // FIXME some models are segfaulting with WebGPU: #ifdef GGML_USE_WEBGPU diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 1dc19536866..6678eaa9aca 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -57,6 +57,11 @@ struct server_slot { llama_context * ctx_tgt = nullptr; llama_context * ctx_dft = nullptr; + // True when this slot's speculative impl is MTP (ctx_dft is the MTP head). + // MTP needs every prefill position to carry logits=1 so the streaming + // hook in common_speculative_state_mtp::process() can read t_h_pre_norm. + bool is_mtp_enabled = false; + // multimodal mtmd_context * mctx = nullptr; @@ -238,8 +243,20 @@ struct server_slot { (ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size()); } + bool is_mtp() const { return is_mtp_enabled; } + + // The trunk needs to emit logits at every prefill position when either: + // - the task asked for embeddings, or + // - MTP is enabled for this slot (the streaming hook in process() reads + // h_pre_norm at every prompt position). 
+ bool need_embd() const { + GGML_ASSERT(task); + return task->need_embd() || is_mtp(); + } + // if the context does not have a memory module then all embeddings have to be computed within a single ubatch // also we cannot split if the pooling would require any past tokens + // (MTP supports splitting — uses task->need_embd() not need_embd()) bool can_split() const { GGML_ASSERT(task); @@ -779,6 +796,53 @@ struct server_context_impl { ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); + params_base.speculative.draft.ctx_tgt = ctx_tgt; + params_base.speculative.draft.ctx_dft = ctx_dft.get(); + } else if (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) { + // MTP head lives in the *target* GGUF — load it as a sibling model + // with override_arch and feed it through the existing ctx_dft slot. + char trunk_arch[64] = {0}; + llama_model_meta_val_str(model_tgt, "general.architecture", trunk_arch, sizeof(trunk_arch)); + + const char * mtp_arch = nullptr; + if (std::string(trunk_arch) == "qwen35") { + mtp_arch = "qwen35_mtp"; + } else if (std::string(trunk_arch) == "qwen35moe") { + mtp_arch = "qwen35moe_mtp"; + } else { + SRV_ERR("MTP not supported for trunk architecture '%s'\n", trunk_arch); + return false; + } + + if (params_base.n_parallel > 1) { + SRV_ERR("MTP currently supports only n_parallel=1; got %d\n", params_base.n_parallel); + return false; + } + + SRV_INF("loading MTP head from '%s' (override_arch=%s)\n", + params_base.model.path.c_str(), mtp_arch); + + auto mparams_mtp = common_model_params_to_llama(params_base); + mparams_mtp.override_arch = mtp_arch; + + model_dft.reset(llama_model_load_from_file(params_base.model.path.c_str(), mparams_mtp)); + if (model_dft == nullptr) { + SRV_ERR("failed to load MTP head from '%s'\n", params_base.model.path.c_str()); + return false; + } + + auto cparams_mtp = common_context_params_to_llama(params_base); + cparams_mtp.n_ctx = llama_n_ctx_seq(ctx_tgt); + cparams_mtp.n_seq_max = 1; + + 
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams_mtp)); + if (ctx_dft == nullptr) { + SRV_ERR("%s", "failed to create MTP context\n"); + return false; + } + + ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); + params_base.speculative.draft.ctx_tgt = ctx_tgt; params_base.speculative.draft.ctx_dft = ctx_dft.get(); } @@ -887,6 +951,7 @@ struct server_context_impl { slot.ctx_tgt = ctx_tgt; slot.ctx_dft = ctx_dft.get(); slot.spec = spec.get(); + slot.is_mtp_enabled = (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) && (ctx_dft != nullptr); slot.n_ctx = n_ctx_slot; slot.mctx = mctx; @@ -2758,12 +2823,14 @@ struct server_context_impl { break; } - // embedding requires all tokens in the batch to be output + // embedding requires all tokens in the batch to be output; + // MTP also wants logits at every prompt position so the + // streaming hook can mirror t_h_pre_norm into ctx_dft. common_batch_add(batch, cur_tok, slot.prompt.tokens.pos_next(), { slot.id }, - slot.task->need_embd()); + slot.need_embd()); slot.prompt.tokens.push_back(cur_tok); slot.n_prompt_tokens_processed++; @@ -2877,7 +2944,7 @@ struct server_context_impl { slot_batched->lora[alora_disabled_id].scale = alora_scale; } - llama_set_embeddings(ctx_tgt, slot_batched->task->need_embd()); + llama_set_embeddings(ctx_tgt, slot_batched->need_embd()); } if (batch.n_tokens == 0) { From 80e1f3c4eabe638d9cecc765ebb249ea85224043 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Mon, 11 May 2026 12:22:37 +0800 Subject: [PATCH 02/28] fix batch size --- common/speculative.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index b6530253ef5..dea1096e6b0 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -401,11 +401,11 @@ struct common_speculative_state_mtp : public common_speculative_impl { n_embd = llama_model_n_embd(llama_get_model(ctx_dft)); - const int32_t n_ub = (int32_t) 
llama_n_ubatch(ctx_dft); - batch = llama_batch_init(/*n_tokens=*/ n_ub, /*embd=*/ n_embd, /*n_seq_max=*/ 1); + const int32_t n_b = (int32_t) llama_n_batch(ctx_dft); + batch = llama_batch_init(/*n_tokens=*/ n_b, /*embd=*/ n_embd, /*n_seq_max=*/ 1); // llama_batch_init allocates only one of token/embd; MTP needs both. // TODO: fix, how to call without malloc - batch.token = (llama_token *) malloc(sizeof(llama_token) * n_ub); + batch.token = (llama_token *) malloc(sizeof(llama_token) * n_b); smpls.resize(n_seq); for (auto & s : smpls) { From 8d16341853b0b616d6286b2e3c84b16147de5420 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Mon, 11 May 2026 17:09:05 +0800 Subject: [PATCH 03/28] rename files --- src/models/{qwen35_mtp.cpp => qwen35-mtp.cpp} | 0 src/models/{qwen35moe_mtp.cpp => qwen35moe-mtp.cpp} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename src/models/{qwen35_mtp.cpp => qwen35-mtp.cpp} (100%) rename src/models/{qwen35moe_mtp.cpp => qwen35moe-mtp.cpp} (100%) diff --git a/src/models/qwen35_mtp.cpp b/src/models/qwen35-mtp.cpp similarity index 100% rename from src/models/qwen35_mtp.cpp rename to src/models/qwen35-mtp.cpp diff --git a/src/models/qwen35moe_mtp.cpp b/src/models/qwen35moe-mtp.cpp similarity index 100% rename from src/models/qwen35moe_mtp.cpp rename to src/models/qwen35moe-mtp.cpp From 5e1965d68b0018a4605b14b0d994143169a4d30b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 11 May 2026 17:26:03 +0300 Subject: [PATCH 04/28] cont : simplify (#7) --- common/speculative.cpp | 344 ++++++++++++++------------------ tools/server/server-context.cpp | 14 +- 2 files changed, 157 insertions(+), 201 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index dea1096e6b0..bbddf34382a 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -379,22 +379,14 @@ struct common_speculative_state_mtp : public common_speculative_impl { // The last h-row of one process() call needs the first token of the NEXT // call to pair 
with, so it's stashed here until that next call fires. std::vector> pending_h; // [n_seq][n_embd] - std::vector pending_pos; // [n_seq] - std::vector last_n_drafted; - std::vector last_n_accepted; - - // Number of trunk output rows produced by the most recent process() call. - // Used by draft() for the first AR step (when last_n_accepted is -1) to - // pick the last prefill row out of ctx_tgt's pre-norm buffer. - std::vector last_trunk_n_outputs; + std::vector i_batch_beg; + std::vector i_batch_end; common_speculative_state_mtp(const common_params_speculative & params, uint32_t n_seq) : common_speculative_impl(COMMON_SPECULATIVE_TYPE_MTP, n_seq) , params(params.draft) { - GGML_ASSERT(n_seq == 1 && "MTP currently supports only single-sequence speculation"); - auto * ctx_tgt = this->params.ctx_tgt; auto * ctx_dft = this->params.ctx_dft; GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set"); @@ -411,7 +403,7 @@ struct common_speculative_state_mtp : public common_speculative_impl { for (auto & s : smpls) { common_params_sampling sparams; sparams.no_perf = false; - sparams.top_k = 1; + sparams.top_k = 10; sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K }; s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams)); } @@ -420,11 +412,9 @@ struct common_speculative_state_mtp : public common_speculative_impl { llama_set_embeddings_pre_norm(ctx_dft, true); pending_h.assign(n_seq, std::vector(n_embd, 0.0f)); - pending_pos.assign(n_seq, -1); - last_n_drafted.assign(n_seq, 0); - last_n_accepted.assign(n_seq, -1); - last_trunk_n_outputs.assign(n_seq, 0); + i_batch_beg.assign(n_seq, -1); + i_batch_end.assign(n_seq, -1); } ~common_speculative_state_mtp() override { @@ -436,12 +426,6 @@ struct common_speculative_state_mtp : public common_speculative_impl { } void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { - GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < pending_pos.size()); - - last_n_accepted[seq_id] = -1; - last_n_drafted 
[seq_id] = 0; - pending_pos [seq_id] = -1; - const int32_t N = (int32_t) prompt.size(); if (N <= 0) { return; @@ -462,231 +446,207 @@ struct common_speculative_state_mtp : public common_speculative_impl { return true; } - // Single-seq for now (asserted in ctor). Future: bucket by seq_id. - const llama_seq_id seq_id = 0; - // TODO: how to make it work with vision tokens? if (batch_in.token == nullptr || batch_in.embd != nullptr) { - pending_pos[seq_id] = -1; return true; } - auto * ctx_tgt = this->params.ctx_tgt; - auto * ctx_dft = this->params.ctx_dft; + const int32_t n_tokens = batch_in.n_tokens; - const int32_t n_rows = batch_in.n_tokens; - const llama_pos pos_start = batch_in.pos[0]; + // remember the frist and last batch index for each sequence + std::fill(i_batch_beg.begin(), i_batch_beg.end(), -1); + std::fill(i_batch_end.begin(), i_batch_end.end(), -1); - const llama_pos pos_max_dft = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id); - if (pos_start <= pos_max_dft) { - return true; - } + for (int k = 0; k < n_tokens; ++k) { + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + GGML_ASSERT(batch_in.n_seq_id[k] == 1); - // Stale pending: discard if the new batch doesn't start one past it. - const bool pending_continues = - pending_pos[seq_id] >= 0 && pending_pos[seq_id] + 1 == pos_start; - if (pending_pos[seq_id] >= 0 && !pending_continues) { - pending_pos[seq_id] = -1; + if (batch_in.seq_id[k][0] == seq_id) { + i_batch_end[seq_id] = k; + if (i_batch_beg[seq_id] < 0) { + i_batch_beg[seq_id] = k; + } + } + } } - // Build a paired hook batch: - // row 0 = (pending_h, batch_in.token[0]) at pos_start if pending_continues - // rows 1..n_rows-1 = (h_k from this batch, batch_in.token[k+1]) at pos[k+1] - // The last h-row (h_{n_rows-1}) is stashed as the new pending and is *not* - // decoded this call — it waits for the next batch's first token to pair. 
+ auto * ctx_tgt = this->params.ctx_tgt; + auto * ctx_dft = this->params.ctx_dft; + const size_t row_bytes = (size_t) n_embd * sizeof(float); common_batch_clear(batch); - int out_idx = 0; - - auto add_pair = [&](const float * h_row, llama_token tok, llama_pos pos) { - std::memcpy(batch.embd + (size_t) out_idx * n_embd, h_row, row_bytes); - batch.token [out_idx] = tok; - batch.pos [out_idx] = pos; - batch.n_seq_id[out_idx] = 1; - batch.seq_id [out_idx][0] = seq_id; - batch.logits [out_idx] = 0; - ++out_idx; - }; - if (pending_continues) { - add_pair(pending_h[seq_id].data(), batch_in.token[0], pos_start); + for (int k = 0; k < n_tokens; ++k) { + common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0); } - // TODO: is there is a fast way to build this batch? - for (int k = 0; k + 1 < n_rows; ++k) { - if (batch_in.logits[k] == 0) { - LOG_WRN("%s: batch_in.logits[%d] == 0 (need_embd / logits=1 missing on prefill); stopping hook at this row\n", - __func__, k); - break; - } - const float * h_k = llama_get_embeddings_pre_norm_ith(ctx_tgt, k); - if (h_k == nullptr) { - LOG_WRN("%s: ctx_tgt has no pre-norm row at i=%d; stopping hook\n", __func__, k); - break; - } - add_pair(h_k, batch_in.token[k + 1], batch_in.pos[k + 1]); - } + // shift the tgt embeddings to the right by one position + // assumes that the tokens in the batch are sequential for each sequence + // i.e. 
we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1] + // ^--- this is a problem + // TODO:this is generally true, but would be nice to assert it + { + const float * h_tgt = llama_get_embeddings_pre_norm(ctx_tgt); + std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1)); + + //{ + // // string with seq_ids in the batch + // std::stringstream ss; + // for (int i = 0; i < n_tokens; ++i) { + // ss << batch_in.seq_id[i][0] << ","; + // } + // LOG_WRN("%s: batch_in.seq_id = %s\n", __func__, ss.str().c_str()); + //} + } + + // fill the pending embeddings from a previous run + auto set_h = [&](int idx, const float * h_row) { + std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes); + }; - if (out_idx > 0) { - batch.n_tokens = out_idx; - const int32_t rc = llama_decode(ctx_dft, batch); - if (rc != 0) { - LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d, n=%d)\n", - __func__, (int) rc, (int) pos_start, out_idx); - return false; + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + if (i_batch_beg[seq_id] < 0) { + continue; } + + set_h(i_batch_beg[seq_id], pending_h[seq_id].data()); } - // last_n_accepted < 0) can find the last pre-norm row of this batch. - // We assume every batch position has logits=1 (server sets need_embd - // for MTP slots) → n_outputs == n_tokens. - last_trunk_n_outputs[seq_id] = n_rows; + const int32_t rc = llama_decode(ctx_dft, batch); + if (rc != 0) { + LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]); + return false; + } - // Stash the last h-row (h_{n_rows-1}) as the new pending for the next - // process() call's first token to pair with. 
- if (batch_in.logits[n_rows - 1] != 0) { - const float * h_last = llama_get_embeddings_pre_norm_ith(ctx_tgt, n_rows - 1); - if (h_last != nullptr) { - std::memcpy(pending_h[seq_id].data(), h_last, row_bytes); - pending_pos[seq_id] = batch_in.pos[n_rows - 1]; - } else { - pending_pos[seq_id] = -1; + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + if (i_batch_end[seq_id] < 0) { + continue; } - } else { - // No trunk output at the tail — can't carry over. - pending_pos[seq_id] = -1; + + const float * h_last = llama_get_embeddings_pre_norm_ith(ctx_tgt, i_batch_end[seq_id]); + std::memcpy(pending_h[seq_id].data(), h_last, row_bytes); } return true; } void draft(common_speculative_draft_params_vec & dparams) override { - // Single-seq for now (asserted in ctor). Future: iterate over dparams. - const llama_seq_id seq_id = 0; - if ((size_t) seq_id >= dparams.size()) { - return; + auto & ctx_dft = params.ctx_dft; + + common_batch_clear(batch); + + // keep track of which sequences are still drafting + int n_drafting = 0; + std::vector drafting(n_seq); + + const float * h_row = nullptr; + const size_t row_bytes = (size_t) n_embd * sizeof(float); + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; + + if (!dp.drafting) { + continue; + } + + n_drafting++; + drafting[seq_id] = true; + common_sampler_reset(smpls[seq_id].get()); + + common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true); + + h_row = pending_h[seq_id].data(); + std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes); } - auto & dp = dparams[seq_id]; - if (!dp.drafting) { + + int ret = llama_decode(ctx_dft, batch); + if (ret != 0) { + LOG_WRN("%s: llama_decode returned %d\n", __func__, ret); return; } - auto * ctx_tgt = this->params.ctx_tgt; - auto * ctx_dft = this->params.ctx_dft; - auto * smpl = smpls[seq_id].get(); - - GGML_ASSERT(dp.result != nullptr); - auto & draft_tokens = *dp.result; - 
draft_tokens.clear(); - - if (last_n_drafted[seq_id] > 0) { - const int32_t n_to_drop = (int32_t) last_n_drafted[seq_id] - 1; - if (n_to_drop > 0) { - const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id); - if (pos_max >= 0) { - const llama_pos drop_from = pos_max - n_to_drop + 1; - llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1); + int i = 0; + + while (n_drafting > 0) { + int i_batch = 0; + + common_batch_clear(batch); + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + if (!drafting[seq_id]) { + continue; } - } - last_n_drafted[seq_id] = 0; - last_n_accepted[seq_id] = 0; - } - // Effective draft length: min(global cap, per-sequence override). - int32_t n_max = std::max(1, params.n_max); - if (dp.n_max > 0) { - n_max = std::min(n_max, dp.n_max); - } + auto * smpl = smpls[seq_id].get(); - const size_t row_bytes = (size_t) n_embd * sizeof(float); + common_sampler_sample(smpl, ctx_dft, i_batch, true); + h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, i_batch); + ++i_batch; - common_sampler_reset(smpl); + const auto * cur_p = common_sampler_get_candidates(smpl, true); - llama_token cond_tok = dp.id_last; - llama_pos pos = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id) + 1; + for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p, + common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); + } - for (int32_t k = 0; k < n_max; ++k) { - const float * h_row = nullptr; + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; - if (k == 0) { - // Condition on the trunk's pre-norm row. - int32_t row_idx; - if (last_n_accepted[seq_id] < 0) { - // First draft after begin(): use the last prefill row. 
- row_idx = last_trunk_n_outputs[seq_id] - 1; - } else { - // After accept(n_accepted): row of the next conditioning - // position in the trunk's verify batch. - row_idx = last_n_accepted[seq_id]; + // only collect very high-confidence draft tokens + if (cur_p->data[0].p < params.p_min) { + drafting[seq_id] = false; + n_drafting--; + + continue; } - if (row_idx < 0) { - LOG_WRN("%s: no trunk pre-norm row available (row_idx=%d); stopping chain\n", - __func__, row_idx); - break; + + common_sampler_accept(smpl, id, true); + + auto & dp = dparams.at(seq_id); + auto & result = *dp.result; + + result.push_back(id); + + if ((params.n_max <= (int) result.size()) || + (dp.n_max > 0 && dp.n_max <= (int) result.size())) { + drafting[seq_id] = false; + n_drafting--; + continue; } - h_row = llama_get_embeddings_pre_norm_ith(ctx_tgt, row_idx); - } else { - // AR step: condition on the MTP head's own pre-norm row from - // the just-completed single-token decode. n_outputs=1 there, - // so the row is at batch position 0. - h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, 0); + + common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true); + std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes); } - if (h_row == nullptr) { - LOG_WRN("%s: missing pre-norm row at k=%d; stopping chain\n", __func__, k); + if (batch.n_tokens == 0) { break; } - // 1-token batch carrying both (token, h_pre_norm). 
- common_batch_clear(batch); - std::memcpy(batch.embd, h_row, row_bytes); - batch.token [0] = cond_tok; - batch.pos [0] = pos; - batch.n_seq_id[0] = 1; - batch.seq_id [0][0] = seq_id; - batch.logits [0] = 1; // need logits for sampling - batch.n_tokens = 1; - - const int32_t rc = llama_decode(ctx_dft, batch); - if (rc != 0) { - LOG_WRN("%s: llama_decode(ctx_dft) failed rc=%d at k=%d; stopping chain\n", - __func__, rc, k); + // evaluate the drafted tokens on the draft model + ret = llama_decode(ctx_dft, batch); + if (ret != 0) { + LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret); break; } - const llama_token best = common_sampler_sample(smpl, ctx_dft, 0); - common_sampler_accept(smpl, best, /*is_generated=*/ false); - draft_tokens.push_back(best); - cond_tok = best; - ++pos; + ++i; } - last_n_drafted[seq_id] = (uint16_t) draft_tokens.size(); - } - - void accept(llama_seq_id seq_id, uint16_t n_accepted) override { - GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < last_n_drafted.size()); - - auto * ctx_dft = this->params.ctx_dft; - - const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id); - const int32_t n_drafted_last = (int32_t) last_n_drafted[seq_id]; - - const int32_t n_to_drop = std::max(0, n_drafted_last - (int32_t) n_accepted - 1); - - if (pos_max < 0) { - last_n_accepted[seq_id] = (int32_t) n_accepted; - return; - } + for (auto & dp : dparams) { + if (!dp.drafting) { + continue; + } - if (n_to_drop > 0) { - const llama_pos drop_from = pos_max - n_to_drop + 1; - llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1); + if (dp.result->size() < (size_t) params.n_min) { + dp.result->clear(); + } } + } - last_n_drafted [seq_id] = 0; - last_n_accepted[seq_id] = (int32_t) n_accepted; + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { } }; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 6678eaa9aca..51770c73b74 100644 --- 
a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -798,7 +798,8 @@ struct server_context_impl { params_base.speculative.draft.ctx_tgt = ctx_tgt; params_base.speculative.draft.ctx_dft = ctx_dft.get(); - } else if (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) { + } else if (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(), + COMMON_SPECULATIVE_TYPE_MTP) != params_base.speculative.types.end()) { // MTP head lives in the *target* GGUF — load it as a sibling model // with override_arch and feed it through the existing ctx_dft slot. char trunk_arch[64] = {0}; @@ -814,11 +815,6 @@ struct server_context_impl { return false; } - if (params_base.n_parallel > 1) { - SRV_ERR("MTP currently supports only n_parallel=1; got %d\n", params_base.n_parallel); - return false; - } - SRV_INF("loading MTP head from '%s' (override_arch=%s)\n", params_base.model.path.c_str(), mtp_arch); @@ -832,8 +828,6 @@ struct server_context_impl { } auto cparams_mtp = common_context_params_to_llama(params_base); - cparams_mtp.n_ctx = llama_n_ctx_seq(ctx_tgt); - cparams_mtp.n_seq_max = 1; ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams_mtp)); if (ctx_dft == nullptr) { @@ -951,7 +945,9 @@ struct server_context_impl { slot.ctx_tgt = ctx_tgt; slot.ctx_dft = ctx_dft.get(); slot.spec = spec.get(); - slot.is_mtp_enabled = (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) && (ctx_dft != nullptr); + slot.is_mtp_enabled = (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(), + COMMON_SPECULATIVE_TYPE_MTP) != params_base.speculative.types.end()) + && (ctx_dft != nullptr); slot.n_ctx = n_ctx_slot; slot.mctx = mctx; From 89f6e0df5e2d601eac27cb210ed15027aaf86c22 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 13 May 2026 11:12:20 +0800 Subject: [PATCH 05/28] MTP: clean-up (#9) * MTP: clean-up * review: use llama_context_type instead of llama_graph_type * review: remove 
llama_model_has_mtp * review: fix convert issues * convert: fix pycheck * review: formatting * use `mtp-` for identifying mtp models * convert: fix mtp conversion --- common/arg.cpp | 27 ++- common/download.cpp | 55 ++++-- common/download.h | 7 +- common/speculative.cpp | 2 +- conversion/base.py | 1 + conversion/qwen.py | 44 ++++- convert_hf_to_gguf.py | 22 +++ include/llama.h | 6 + src/llama-arch.cpp | 2 - src/llama-arch.h | 2 - src/llama-context.cpp | 27 ++- src/llama-cparams.h | 1 + src/llama-graph.h | 1 + src/llama-memory.h | 3 + src/llama-model.cpp | 36 ++-- src/models/models.h | 30 +--- src/models/qwen35-mtp.cpp | 207 --------------------- src/models/qwen35.cpp | 254 +++++++++++++++++++++----- src/models/qwen35moe-mtp.cpp | 252 -------------------------- src/models/qwen35moe.cpp | 306 +++++++++++++++++++++++++++----- tests/test-llama-archs.cpp | 6 +- tools/server/server-context.cpp | 39 ++-- 22 files changed, 685 insertions(+), 645 deletions(-) delete mode 100644 src/models/qwen35-mtp.cpp delete mode 100644 src/models/qwen35moe-mtp.cpp diff --git a/common/arg.cpp b/common/arg.cpp index 2129a9c7266..747a9e81990 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -337,11 +337,15 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa struct handle_model_result { bool found_mmproj = false; common_params_model mmproj; + + bool found_mtp = false; + common_params_model mtp; }; static handle_model_result common_params_handle_model(struct common_params_model & model, const std::string & bearer_token, - bool offline) { + bool offline, + bool search_mtp = false) { handle_model_result result; if (!model.docker_repo.empty()) { @@ -356,7 +360,7 @@ static handle_model_result common_params_handle_model(struct common_params_model common_download_opts opts; opts.bearer_token = bearer_token; opts.offline = offline; - auto download_result = common_download_model(model, opts, true); + auto download_result = common_download_model(model, opts, true, 
search_mtp); if (download_result.model_path.empty()) { throw std::runtime_error("failed to download model from Hugging Face"); @@ -369,6 +373,11 @@ static handle_model_result common_params_handle_model(struct common_params_model result.found_mmproj = true; result.mmproj.path = download_result.mmproj_path; } + + if (!download_result.mtp_path.empty()) { + result.found_mtp = true; + result.mtp.path = download_result.mtp_path; + } } else if (!model.url.empty()) { if (model.path.empty()) { auto f = string_split(model.url, '#').front(); @@ -436,7 +445,11 @@ static bool parse_bool_value(const std::string & value) { // void common_params_handle_models(common_params & params, llama_example curr_ex) { - auto res = common_params_handle_model(params.model, params.hf_token, params.offline); + const bool spec_type_mtp = std::find(params.speculative.types.begin(), + params.speculative.types.end(), + COMMON_SPECULATIVE_TYPE_MTP) != params.speculative.types.end(); + + auto res = common_params_handle_model(params.model, params.hf_token, params.offline, spec_type_mtp); if (params.no_mmproj) { params.mmproj = {}; } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) { @@ -450,6 +463,14 @@ void common_params_handle_models(common_params & params, llama_example curr_ex) break; } } + // when --spec-type mtp is set and no draft model was provided explicitly, + // fall back to the MTP head discovered alongside the -hf model + if (spec_type_mtp && res.found_mtp && + params.speculative.draft.mparams.path.empty() && + params.speculative.draft.mparams.hf_repo.empty() && + params.speculative.draft.mparams.url.empty()) { + params.speculative.draft.mparams.path = res.mtp.path; + } common_params_handle_model(params.speculative.draft.mparams, params.hf_token, params.offline); common_params_handle_model(params.vocoder.model, params.hf_token, params.offline); } diff --git a/common/download.cpp b/common/download.cpp index 0bf12ad4a3b..f3dacb7e3e0 100644 --- 
a/common/download.cpp +++ b/common/download.cpp @@ -566,8 +566,11 @@ static hf_cache::hf_files get_split_files(const hf_cache::hf_files & files, return result; } -static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files, - const std::string & model) { +// pick the best sibling GGUF whose filename contains `keyword` (e.g. "mmproj" / "mtp"), +// preferring deeper shared directory prefix with the model, then closest quantization +static hf_cache::hf_file find_best_sibling(const hf_cache::hf_files & files, + const std::string & model, + const std::string & keyword) { hf_cache::hf_file best; size_t best_depth = 0; int best_diff = 0; @@ -579,20 +582,20 @@ static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files, for (const auto & f : files) { if (!string_ends_with(f.path, ".gguf") || - f.path.find("mmproj") == std::string::npos) { + f.path.find(keyword) == std::string::npos) { continue; } - auto mmproj_parts = string_split(f.path, '/'); - auto mmproj_dir = mmproj_parts.end() - 1; + auto sib_parts = string_split(f.path, '/'); + auto sib_dir = sib_parts.end() - 1; auto [_, dir] = std::mismatch(model_parts.begin(), model_dir, - mmproj_parts.begin(), mmproj_dir); - if (dir != mmproj_dir) { + sib_parts.begin(), sib_dir); + if (dir != sib_dir) { continue; } - size_t depth = dir - mmproj_parts.begin(); + size_t depth = dir - sib_parts.begin(); auto bits = extract_quant_bits(f.path); auto diff = std::abs(bits - model_bits); @@ -606,6 +609,16 @@ static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files, return best; } +static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files, + const std::string & model) { + return find_best_sibling(files, model, "mmproj"); +} + +static hf_cache::hf_file find_best_mtp(const hf_cache::hf_files & files, + const std::string & model) { + return find_best_sibling(files, model, "mtp-"); +} + static bool gguf_filename_is_model(const std::string & filepath) { if 
(!string_ends_with(filepath, ".gguf")) { return false; @@ -617,7 +630,8 @@ static bool gguf_filename_is_model(const std::string & filepath) { } return filename.find("mmproj") == std::string::npos && - filename.find("imatrix") == std::string::npos; + filename.find("imatrix") == std::string::npos && + filename.find("mtp-") == std::string::npos; } static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files, @@ -673,11 +687,13 @@ struct hf_plan { hf_cache::hf_file primary; hf_cache::hf_files model_files; hf_cache::hf_file mmproj; + hf_cache::hf_file mtp; }; static hf_plan get_hf_plan(const common_params_model & model, const common_download_opts & opts, - bool download_mmproj) { + bool download_mmproj, + bool download_mtp) { hf_plan plan; hf_cache::hf_files all; @@ -723,6 +739,10 @@ static hf_plan get_hf_plan(const common_params_model & model, plan.mmproj = find_best_mmproj(all, primary.path); } + if (download_mtp) { + plan.mtp = find_best_mtp(all, primary.path); + } + return plan; } @@ -756,7 +776,8 @@ static std::vector get_url_tasks(const common_params_model & mode common_download_model_result common_download_model(const common_params_model & model, const common_download_opts & opts, - bool download_mmproj) { + bool download_mmproj, + bool download_mtp) { common_download_model_result result; std::vector tasks; hf_plan hf; @@ -764,13 +785,16 @@ common_download_model_result common_download_model(const common_params_model & bool is_hf = !model.hf_repo.empty(); if (is_hf) { - hf = get_hf_plan(model, opts, download_mmproj); + hf = get_hf_plan(model, opts, download_mmproj, download_mtp); for (const auto & f : hf.model_files) { tasks.push_back({f.url, f.local_path}); } if (!hf.mmproj.path.empty()) { tasks.push_back({hf.mmproj.url, hf.mmproj.local_path}); } + if (!hf.mtp.path.empty()) { + tasks.push_back({hf.mtp.url, hf.mtp.local_path}); + } } else if (!model.url.empty()) { tasks = get_url_tasks(model); } else { @@ -807,6 +831,10 @@ common_download_model_result 
common_download_model(const common_params_model & if (!hf.mmproj.path.empty()) { result.mmproj_path = hf_cache::finalize_file(hf.mmproj); } + + if (!hf.mtp.path.empty()) { + result.mtp_path = hf_cache::finalize_file(hf.mtp); + } } else { result.model_path = model.path; } @@ -946,7 +974,8 @@ std::vector common_list_cached_models() { for (const auto & f : files) { auto split = get_gguf_split_info(f.path); if (split.index != 1 || split.tag.empty() || - split.prefix.find("mmproj") != std::string::npos) { + split.prefix.find("mmproj") != std::string::npos || + split.prefix.find("MTP") != std::string::npos) { continue; } if (seen.insert(f.repo_id + ":" + split.tag).second) { diff --git a/common/download.h b/common/download.h index edc3e9f1a71..4a169ef7796 100644 --- a/common/download.h +++ b/common/download.h @@ -59,6 +59,7 @@ struct common_download_opts { struct common_download_model_result { std::string model_path; std::string mmproj_path; + std::string mtp_path; }; // Download model from HuggingFace repo or URL @@ -83,12 +84,14 @@ struct common_download_model_result { // when opts.offline=true, no network requests are made // when download_mmproj=true, searches for mmproj in same directory as model or any parent directory // then with the closest quantization bits +// when download_mtp=true, applies the same sibling search for an MTP-head GGUF // -// returns result with model_path and mmproj_path (empty on failure) +// returns result with model_path, mmproj_path and mtp_path (empty when not found / on failure) common_download_model_result common_download_model( const common_params_model & model, const common_download_opts & opts = {}, - bool download_mmproj = false + bool download_mmproj = false, + bool download_mtp = false ); // returns list of cached models diff --git a/common/speculative.cpp b/common/speculative.cpp index bbddf34382a..f064009a225 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -1198,7 +1198,7 @@ common_speculative * 
common_speculative_init(common_params_speculative & params, LOG_WRN("%s: draft model is not specified - cannot use 'draft' type\n", __func__); has_draft_simple = false; } - } else if (has_draft_model_path) { + } else if (has_draft_model_path && !has_mtp && !has_draft_eagle3) { LOG_WRN("%s: draft model is specified but 'draft' speculative type is not explicitly enabled - enabling it\n", __func__); has_draft_simple = true; } diff --git a/conversion/base.py b/conversion/base.py index d89d32fe150..3c4be034154 100644 --- a/conversion/base.py +++ b/conversion/base.py @@ -91,6 +91,7 @@ class ModelBase: gguf_writer: gguf.GGUFWriter model_name: str | None metadata_override: Path | None + metadata: gguf.Metadata dir_model_card: Path remote_hf_model_id: str | None diff --git a/conversion/qwen.py b/conversion/qwen.py index e8b49a22bdb..78c8293b86b 100644 --- a/conversion/qwen.py +++ b/conversion/qwen.py @@ -548,22 +548,54 @@ class _Qwen35MtpMixin: block_count: int tensor_map: gguf.TensorNameMap + mtp_only: bool = False + no_mtp: bool = False + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("mtp_num_hidden_layers", 0) + self.block_count = self.hparams["num_hidden_layers"] + if not self.no_mtp: + self.block_count += self.hparams.get("mtp_num_hidden_layers", 0) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + @classmethod + def filter_tensors(cls, item): + name, _ = item + if name.startswith("mtp."): + if cls.no_mtp: + return None + return item + if cls.mtp_only: + canonical = name.replace("language_model.", "") + keep = canonical in ( + "model.embed_tokens.weight", "model.norm.weight", "lm_head.weight", + "embed_tokens.weight", "norm.weight", + ) + if not keep: + return None + return super().filter_tensors(item) # ty: ignore[unresolved-attribute] + def set_gguf_parameters(self): super().set_gguf_parameters() # ty: ignore[unresolved-attribute] + if 
self.no_mtp: + return if (n := self.hparams.get("mtp_num_hidden_layers", 0)) > 0: self.gguf_writer.add_nextn_predict_layers(n) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - if name.startswith("model.language_model."): - name = "model." + name[len("model.language_model."):] - elif name.startswith("language_model."): - name = name[len("language_model."):] + def prepare_metadata(self, vocab_only: bool): + from_dir = self.fname_out.is_dir() + super().prepare_metadata(vocab_only=vocab_only) # ty: ignore[unresolved-attribute] + + if not self.mtp_only or not from_dir: + return + output_type: str = self.ftype.name.partition("_")[2] # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute] + fname_default: str = gguf.naming_convention( + self.metadata.name, self.metadata.basename, self.metadata.finetune, # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute] + self.metadata.version, size_label=None, output_type=output_type, model_type=None) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute] + self.fname_out = self.fname_out.parent / f"mtp-{fname_default}.gguf" + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: if name.startswith("mtp."): n_layer = self.hparams["num_hidden_layers"] if name.find("layers.") != -1: diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 7173f616009..ff840050861 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -117,6 +117,14 @@ def parse_args() -> argparse.Namespace: "--mmproj", action="store_true", help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. 
A prefix 'mmproj-' will be added to the output file name.", ) + parser.add_argument( + "--mtp", action="store_true", + help="(Experimental) Export only the multi-token prediction (MTP) head as a separate GGUF, suitable for use as a speculative draft. The default output file name gets an 'mtp-' prefix.", + ) + parser.add_argument( + "--no-mtp", action="store_true", + help="(Experimental) Exclude the multi-token prediction (MTP) head from the converted GGUF. Pair with --mtp on a second run to publish trunk and MTP as two files. Note: the split form duplicates embeddings, so the bundled default is more space-efficient overall.", + ) parser.add_argument( "--mistral-format", action="store_true", help="Whether the model is stored following the Mistral format.", @@ -233,6 +241,20 @@ def main() -> None: from conversion.mistral import MistralModel model_class = MistralModel + if args.mtp and args.no_mtp: + logger.error("--mtp and --no-mtp are mutually exclusive") + sys.exit(1) + + if args.mtp or args.no_mtp: + from conversion.qwen import _Qwen35MtpMixin + if not issubclass(model_class, _Qwen35MtpMixin): + logger.error("--mtp / --no-mtp are only supported for Qwen3.5/3.6 text variants today") + sys.exit(1) + if args.no_mtp: + model_class.no_mtp = True + if args.mtp: + model_class.mtp_only = True + model_instance = model_class(dir_model, output_type, fname_out, is_big_endian=args.bigendian, use_temp_file=args.use_temp_file, eager=args.no_lazy, diff --git a/include/llama.h b/include/llama.h index 1b896944735..b814e2c58de 100644 --- a/include/llama.h +++ b/include/llama.h @@ -198,6 +198,11 @@ extern "C" { LLAMA_SPLIT_MODE_TENSOR = 3, }; + enum llama_context_type { + LLAMA_CONTEXT_TYPE_DEFAULT = 0, + LLAMA_CONTEXT_TYPE_MTP = 1, + }; + // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979) typedef struct llama_token_data { llama_token id; // token id @@ -339,6 +344,7 @@ extern "C" { int32_t n_threads; // number of threads to use for generation 
int32_t n_threads_batch; // number of threads to use for batch processing + enum llama_context_type ctx_type; // set the context type (e.g. MTP) enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_attention_type attention_type; // attention type to use for embeddings diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 794666d09a4..ab4334da79b 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -41,8 +41,6 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" }, { LLM_ARCH_QWEN35, "qwen35" }, { LLM_ARCH_QWEN35MOE, "qwen35moe" }, - { LLM_ARCH_QWEN35_MTP, "qwen35_mtp" }, - { LLM_ARCH_QWEN35MOE_MTP, "qwen35moe_mtp" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHIMOE, "phimoe" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index 71c2ca6e6b3..e37d548c98e 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -45,8 +45,6 @@ enum llm_arch { LLM_ARCH_QWEN3VLMOE, LLM_ARCH_QWEN35, LLM_ARCH_QWEN35MOE, - LLM_ARCH_QWEN35_MTP, - LLM_ARCH_QWEN35MOE_MTP, LLM_ARCH_PHI2, LLM_ARCH_PHI3, LLM_ARCH_PHIMOE, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index aea8a0a4e81..6ecbe1b6083 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2,6 +2,7 @@ #include "ggml.h" #include "llama-arch.h" +#include "llama-graph.h" #include "llama-impl.h" #include "llama-batch.h" #include "llama-io.h" @@ -21,6 +22,14 @@ // llama_context // +static llm_graph_type ctx_type_to_graph_type(llama_context_type ctx_type) { + switch (ctx_type) { + case LLAMA_CONTEXT_TYPE_DEFAULT: return LLM_GRAPH_TYPE_DEFAULT; + case LLAMA_CONTEXT_TYPE_MTP : return LLM_GRAPH_TYPE_DECODER_MTP; + } + throw std::runtime_error("Unsupported ctx type"); +} + llama_context::llama_context( const llama_model & model, llama_context_params params) : @@ -66,6 +75,8 @@ 
llama_context::llama_context( cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; + cparams.ctx_type = params.ctx_type; + // Initialize backend samplers here so they are part of the sampling graph // before the reserve passes run later in this function. This avoids a later // re-reserve when graph nodes change. @@ -279,6 +290,7 @@ llama_context::llama_context( /*.type_k =*/ params.type_k, /*.type_v =*/ params.type_v, /*.swa_full =*/ params.swa_full, + /*.ctx_type= */ cparams.ctx_type, }; memory.reset(model.create_memory(params_mem, cparams)); @@ -1738,7 +1750,8 @@ int llama_context::decode(const llama_batch & batch_inp) { } ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); + + const auto * res = process_ubatch(ubatch, ctx_type_to_graph_type(cparams.ctx_type), mctx.get(), status); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module @@ -2198,7 +2211,7 @@ ggml_cgraph * llama_context::graph_reserve( auto * res = gf_res_reserve.get(); - const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx, ctx_type_to_graph_type(cparams.ctx_type)); res->reset(); @@ -3177,7 +3190,7 @@ void llama_context::opt_epoch_iter( auto * res = gf_res_prev.get(); - const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx.get(), ctx_type_to_graph_type(cparams.ctx_type)); res->reset(); @@ -3280,6 +3293,7 @@ llama_context_params llama_context_default_params() { /*.n_seq_max =*/ 1, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, + /*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT, /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, /*.attention_type =*/ 
LLAMA_ATTENTION_TYPE_UNSPECIFIED, @@ -3383,6 +3397,13 @@ llama_context * llama_init_from_model( model->hparams.pooling_type, params.pooling_type); } + if (params.ctx_type == LLAMA_CONTEXT_TYPE_MTP && + model->hparams.nextn_predict_layers == 0) { + LLAMA_LOG_WARN("%s: context type MTP requested but model doesn't contain MTP layers\n", __func__); + return nullptr; + } + + try { auto * ctx = new llama_context(*model, params); return ctx; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 1e4e9e29ed8..cbf74eba63e 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -41,6 +41,7 @@ struct llama_cparams { bool kv_unified; bool pipeline_parallel; + enum llama_context_type ctx_type; enum llama_pooling_type pooling_type; ggml_backend_sched_eval_callback cb_eval; diff --git a/src/llama-graph.h b/src/llama-graph.h index d3cd69a674c..9e55d0a675e 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -32,6 +32,7 @@ enum llm_graph_type { LLM_GRAPH_TYPE_DEFAULT, LLM_GRAPH_TYPE_ENCODER, LLM_GRAPH_TYPE_DECODER, + LLM_GRAPH_TYPE_DECODER_MTP, }; enum llm_ffn_op_type { diff --git a/src/llama-memory.h b/src/llama-memory.h index 4a157b91fdb..4ad1612e45b 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -1,6 +1,7 @@ #pragma once #include "llama.h" +#include "llama-graph.h" #include #include @@ -20,6 +21,8 @@ struct llama_memory_params { // use full-size SWA cache bool swa_full; + + llama_context_type ctx_type; }; enum llama_memory_status { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e14f375521e..5ab183271eb 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -276,10 +276,6 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_qwen35(params); case LLM_ARCH_QWEN35MOE: return new llama_model_qwen35moe(params); - case LLM_ARCH_QWEN35_MTP: - return new llama_model_qwen35_mtp(params); - case LLM_ARCH_QWEN35MOE_MTP: - return new llama_model_qwen35moe_mtp(params); case 
LLM_ARCH_MISTRAL3: return new llama_model_mistral3(params); case LLM_ARCH_MIMO2: @@ -1419,8 +1415,7 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { } } } - const bool partial_load = (arch == LLM_ARCH_QWEN35_MTP || arch == LLM_ARCH_QWEN35MOE_MTP); - ml.done_getting_tensors(partial_load); + ml.done_getting_tensors(); GGML_ASSERT(!(output && tok_embd && strcmp(output->name, tok_embd->name) == 0 && @@ -1961,6 +1956,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, // checks default: { + // The MTP head is dense-attention only on hybrid Qwen3.5/3.6, so use a plain + // attention KV cache for the MTP context instead of the hybrid wrapper. + const bool mtp_on_hybrid_qwen35 = + params.ctx_type == LLAMA_CONTEXT_TYPE_MTP && + (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE); + if (llm_arch_is_recurrent(arch)) { res = new llama_memory_recurrent( *this, @@ -1970,7 +1971,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, std::max((uint32_t) 1, cparams.n_seq_max), cparams.n_seq_max, nullptr); - } else if (llm_arch_is_hybrid(arch)) { + } else if (llm_arch_is_hybrid(arch) && !mtp_on_hybrid_qwen35) { // The main difference between hybrid architectures is the // layer filters, so pick the right one here llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; @@ -1985,6 +1986,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, filter_recr = [&](int32_t il) { return hparams.is_recurrent(il) && hparams.n_ff(il) == 0; }; + } else if (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE) { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + filter_attn = [&, n_main](int32_t il) { + return (uint32_t)il < n_main && !hparams.is_recurrent(il); + }; + filter_recr = [&, n_main](int32_t il) { + return (uint32_t)il < n_main && hparams.is_recurrent(il); + }; } if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { @@ -2027,6 +2036,7 @@ 
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } } else { llama_memory_i::layer_reuse_cb reuse = nullptr; + llama_kv_cache::layer_filter_cb filter = nullptr; if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { reuse = [&](int32_t il) { @@ -2038,6 +2048,11 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, }; } + if (mtp_on_hybrid_qwen35) { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + filter = [n_main](int32_t il) { return (uint32_t)il >= n_main; }; + } + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(hparams.is_swa_any()); @@ -2053,7 +2068,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_seq_max, cparams.n_ubatch, 1, - nullptr, + filter, reuse); } else { GGML_ASSERT(!hparams.is_swa_any()); @@ -2070,7 +2085,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, 1, hparams.n_swa, hparams.swa_type, - nullptr, + filter, nullptr); } } @@ -2174,6 +2189,7 @@ int32_t llama_model_n_swa(const llama_model * model) { return model->hparams.n_swa; } + uint32_t llama_model_n_cls_out(const struct llama_model * model) { return model->hparams.n_cls_out; } @@ -2341,8 +2357,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_QWEN3VLMOE: case LLM_ARCH_QWEN35: case LLM_ARCH_QWEN35MOE: - case LLM_ARCH_QWEN35_MTP: - case LLM_ARCH_QWEN35MOE_MTP: return LLAMA_ROPE_TYPE_IMROPE; case LLM_ARCH_GLM4: diff --git a/src/models/models.h b/src/models/models.h index 1f04d313d13..fe95b9b89ad 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -1739,6 +1739,10 @@ struct llama_model_qwen35 : public llama_model_base { const llama_model & model; }; + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; @@ -1781,30 
+1785,8 @@ struct llama_model_qwen35moe : public llama_model_base { const llama_model & model; }; - std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; -}; - - -struct llama_model_qwen35_mtp : public llama_model_base { - llama_model_qwen35_mtp(const struct llama_model_params & params) : llama_model_base(params) {} - void load_arch_hparams(llama_model_loader & ml) override; - void load_arch_tensors(llama_model_loader & ml) override; - - struct graph : public llm_graph_context { - graph(const llama_model & model, const llm_graph_params & params); - }; - - std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; -}; - - -struct llama_model_qwen35moe_mtp : public llama_model_base { - llama_model_qwen35moe_mtp(const struct llama_model_params & params) : llama_model_base(params) {} - void load_arch_hparams(llama_model_loader & ml) override; - void load_arch_tensors(llama_model_loader & ml) override; - - struct graph : public llm_graph_context { - graph(const llama_model & model, const llm_graph_params & params); + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); }; std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; diff --git a/src/models/qwen35-mtp.cpp b/src/models/qwen35-mtp.cpp deleted file mode 100644 index 83039e98db5..00000000000 --- a/src/models/qwen35-mtp.cpp +++ /dev/null @@ -1,207 +0,0 @@ -#include "models.h" - -void llama_model_qwen35_mtp::load_arch_hparams(llama_model_loader & ml) { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); - - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35_MTP requires nextn_predict_layers > 0"); - GGML_ASSERT(hparams.nextn_predict_layers <= hparams.n_layer); - - // only the MTP layers 
get a KV cache, trunk layers are skipped. - hparams.kv_only_nextn = true; - hparams.n_layer_kv_from_start = -1; - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = false; - } - - type = LLM_TYPE_UNKNOWN; -} - -void llama_model_qwen35_mtp::load_arch_tensors(llama_model_loader &) { - LLAMA_LOAD_LOCALS; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, TENSOR_NOT_REQUIRED); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - if (output == nullptr) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); - } - - const uint32_t n_main = n_layer - hparams.nextn_predict_layers; - for (int i = 0; i < n_layer; ++i) { - if (static_cast(i) < n_main) { - continue; // trunk layer — owned by the sibling QWEN35 model - } - - auto & layer = layers[i]; - - // MTP block looks like a full-attention Qwen3.5 decoder block. - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - // NextN-specific tensors that define the MTP block. 
- layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, 0); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, 0); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, 0); - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); - } -} - -std::unique_ptr llama_model_qwen35_mtp::build_arch_graph(const llm_graph_params & params) const { - return std::make_unique(*this, params); -} - -// LLM_ARCH_QWEN35_MTP draft head for Qwen3.5/3.6 dense series -llama_model_qwen35_mtp::graph::graph(const llama_model & model, const llm_graph_params & params) - : llm_graph_context(params) { - GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35_MTP requires nextn_predict_layers > 0"); - GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35_MTP currently only supports a single MTP block"); - - const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); - - // The MTP block lives at the source file's original layer index. 
- const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; - const auto & layer = model.layers[il]; - - GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); - GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); - GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); - - int sections[4]; - std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); - - auto inp = std::make_unique(hparams.n_embd); - - inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - ggml_set_input(inp->tokens); - - inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); - ggml_set_input(inp->embd); - ggml_set_name(inp->embd, "mtp_h_input"); - - ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; - - ggml_tensor * h_input = inp->embd; - ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); - cb(tok_embd, "mtp_tok_embd", il); - - res->add_input(std::move(inp)); - - ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv(); - - ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); - cb(h_norm, "mtp_hnorm", il); - - ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); - cb(e_norm, "mtp_enorm", il); - - ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); - cb(concat, "mtp_concat", il); - - ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat); - cb(cur, "mtp_eh_proj", il); - - ggml_tensor * inpSA = cur; - - cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); - cb(cur, "mtp_attn_norm", il); - - ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); - cb(Qcur_full, "mtp_Qcur_full", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, - n_embd_head, n_head, n_tokens, - ggml_element_size(Qcur_full) * n_embd_head * 2, - ggml_element_size(Qcur_full) * 
n_embd_head * 2 * n_head, - 0); - Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); - cb(Qcur, "mtp_Qcur_normed", il); - - ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, - n_embd_head, n_head, n_tokens, - ggml_element_size(Qcur_full) * n_embd_head * 2, - ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, - ggml_element_size(Qcur_full) * n_embd_head); - gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); - cb(gate, "mtp_gate", il); - - ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); - cb(Kcur, "mtp_Kcur_normed", il); - - ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cb(Vcur, "mtp_Vcur", il); - - Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, - n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, - n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - - const float kq_scale = hparams.f_attention_scale == 0.0f - ? 
1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; - - cur = build_attn(inp_attn, - nullptr, nullptr, nullptr, - Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); - cb(cur, "mtp_attn_pregate", il); - - cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); - cur = build_lora_mm(layer.wo, cur, layer.wo_s); - cb(cur, "mtp_attn_out", il); - - cur = ggml_add(ctx0, cur, inpSA); - cb(cur, "mtp_attn_residual", il); - - ggml_tensor * ffn_residual = cur; - cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); - cb(cur, "mtp_attn_post_norm", il); - - cur = build_ffn(cur, - layer.ffn_up, nullptr, layer.ffn_up_s, - layer.ffn_gate, nullptr, layer.ffn_gate_s, - layer.ffn_down, nullptr, layer.ffn_down_s, - nullptr, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "mtp_ffn_out", il); - - cur = ggml_add(ctx0, cur, ffn_residual); - cb(cur, "mtp_post_ffn", il); - - // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. - // (In the trunk graph this is `t_h_pre_norm`; the MTP head reuses the same slot.) - cb(cur, "h_pre_norm", -1); - res->t_h_pre_norm = cur; - - ggml_tensor * head_norm_w = layer.nextn.shared_head_norm - ? layer.nextn.shared_head_norm - : model.output_norm; - GGML_ASSERT(head_norm_w && "QWEN35_MTP: missing both nextn.shared_head_norm and output_norm"); - cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); - cb(cur, "mtp_shared_head_norm", -1); - - ggml_tensor * head_w = layer.nextn.shared_head_head ? 
layer.nextn.shared_head_head : model.output; - GGML_ASSERT(head_w && "QWEN35_MTP: missing LM head (nextn.shared_head_head or model.output)"); - cur = build_lora_mm(head_w, cur); - cb(cur, "result_output", -1); - - res->t_logits = cur; - ggml_build_forward_expand(gf, cur); -} diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index 1b7796f775e..e59d7f28856 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -15,7 +15,6 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; // Mark recurrent layers (linear attention layers). MTP layers are dense // attention-only and must be flagged non-recurrent. @@ -36,9 +35,14 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { } } -void llama_model_qwen35::load_arch_tensors(llama_model_loader &) { +void llama_model_qwen35::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + const bool mtp_only = (hparams.nextn_predict_layers > 0) && + (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const int trunk_flags = mtp_only ? 
TENSOR_NOT_REQUIRED : 0; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); // output @@ -50,60 +54,85 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); } - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; + auto load_block_trunk = [&](int il, int flags) { + auto & layer = layers[il]; - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recurrent(il)) { // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 
2, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags); } else { // Linear attention (gated delta net) specific tensors // Create tensors with calculated dimensions - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, 
flags); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags); } - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - // NextN/MTP tensors (preserved but unused) - only bound on MTP layers - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); - } + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, 
n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, flags); + }; + + auto load_block_mtp = [&](int il) { + auto & layer = layers[il]; + + // MTP block looks like a full-attention Qwen3.5 decoder block. + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0); + + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, 0); + + // NextN-specific tensors that define the MTP block. 
+ layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < (int) n_main; ++i) { + load_block_trunk(i, trunk_flags); + } + for (int i = (int) n_main; i < n_layer; ++i) { + load_block_mtp(i); } } std::unique_ptr llama_model_qwen35::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique(*this, params); + } return std::make_unique(*this, params); } @@ -493,3 +522,146 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_ffn(ggml_tensor * cur, cons return cur; } + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 dense series +llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35 MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35 MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + // The MTP block lives at the source file's original layer index. 
+ const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * 
n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 
1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + cur = build_ffn(cur, + layer.ffn_up, nullptr, layer.ffn_up_s, + layer.ffn_gate, nullptr, layer.ffn_gate_s, + layer.ffn_down, nullptr, layer.ffn_down_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + // (In the trunk graph this is `t_h_pre_norm`; the MTP head reuses the same slot.) + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35 MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? 
layer.nextn.shared_head_head : model.output; + GGML_ASSERT(head_w && "QWEN35 MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/qwen35moe-mtp.cpp b/src/models/qwen35moe-mtp.cpp deleted file mode 100644 index 9f662213bee..00000000000 --- a/src/models/qwen35moe-mtp.cpp +++ /dev/null @@ -1,252 +0,0 @@ -#include "models.h" - -void llama_model_qwen35moe_mtp::load_arch_hparams(llama_model_loader & ml) { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); - - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE_MTP requires nextn_predict_layers > 0"); - GGML_ASSERT(hparams.nextn_predict_layers <= hparams.n_layer); - GGML_ASSERT(hparams.n_expert > 0 && "QWEN35MOE_MTP requires n_expert > 0"); - - // only the MTP layers get a KV cache, trunk layers are skipped. 
- hparams.kv_only_nextn = true; - hparams.n_layer_kv_from_start = -1; - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = false; - } - - type = LLM_TYPE_UNKNOWN; -} - -void llama_model_qwen35moe_mtp::load_arch_tensors(llama_model_loader &) { - LLAMA_LOAD_LOCALS; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, TENSOR_NOT_REQUIRED); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - if (output == nullptr) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); - } - - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - - const uint32_t n_main = n_layer - hparams.nextn_predict_layers; - for (int i = 0; i < n_layer; ++i) { - if (static_cast(i) < n_main) { - continue; // trunk layer — owned by the sibling QWEN35MOE model - } - - auto & layer = layers[i]; - - // MTP block looks like a full-attention Qwen3.5 decoder block with MoE FFN. 
- layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); - - // Routed experts - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); - - // Shared experts - layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); - - // NextN-specific tensors that define the MTP block. 
- layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, 0); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, 0); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, 0); - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); - } -} - -std::unique_ptr llama_model_qwen35moe_mtp::build_arch_graph(const llm_graph_params & params) const { - return std::make_unique(*this, params); -} - -// LLM_ARCH_QWEN35MOE_MTP draft head for Qwen3.5/3.6 MoE -llama_model_qwen35moe_mtp::graph::graph(const llama_model & model, const llm_graph_params & params) - : llm_graph_context(params) { - GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE_MTP requires nextn_predict_layers > 0"); - GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35MOE_MTP currently only supports a single MTP block"); - - const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); - - const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; - const auto & layer = model.layers[il]; - - GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); - GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); - GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); - GGML_ASSERT(layer.ffn_gate_inp && "MTP block missing ffn_gate_inp"); - - int sections[4]; - std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); - - auto inp = std::make_unique(hparams.n_embd); - - inp->tokens = ggml_new_tensor_1d(ctx0, 
GGML_TYPE_I32, n_tokens); - ggml_set_input(inp->tokens); - - inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); - ggml_set_input(inp->embd); - ggml_set_name(inp->embd, "mtp_h_input"); - - ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; - - ggml_tensor * h_input = inp->embd; - ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); - cb(tok_embd, "mtp_tok_embd", il); - - res->add_input(std::move(inp)); - - ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv(); - - ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); - cb(h_norm, "mtp_hnorm", il); - - ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); - cb(e_norm, "mtp_enorm", il); - - ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); - cb(concat, "mtp_concat", il); - - ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat); - cb(cur, "mtp_eh_proj", il); - - ggml_tensor * inpSA = cur; - - cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); - cb(cur, "mtp_attn_norm", il); - - ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); - cb(Qcur_full, "mtp_Qcur_full", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, - n_embd_head, n_head, n_tokens, - ggml_element_size(Qcur_full) * n_embd_head * 2, - ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, - 0); - Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); - cb(Qcur, "mtp_Qcur_normed", il); - - ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, - n_embd_head, n_head, n_tokens, - ggml_element_size(Qcur_full) * n_embd_head * 2, - ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, - ggml_element_size(Qcur_full) * n_embd_head); - gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); - cb(gate, "mtp_gate", il); - - ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); 
- Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); - cb(Kcur, "mtp_Kcur_normed", il); - - ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cb(Vcur, "mtp_Vcur", il); - - Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, - n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, - n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - - const float kq_scale = hparams.f_attention_scale == 0.0f - ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; - - cur = build_attn(inp_attn, - nullptr, nullptr, nullptr, - Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); - cb(cur, "mtp_attn_pregate", il); - - cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); - cur = build_lora_mm(layer.wo, cur, layer.wo_s); - cb(cur, "mtp_attn_out", il); - - cur = ggml_add(ctx0, cur, inpSA); - cb(cur, "mtp_attn_residual", il); - - ggml_tensor * ffn_residual = cur; - cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); - cb(cur, "mtp_attn_post_norm", il); - - // MoE FFN — routed experts plus gated shared expert (mirrors qwen35moe). 
- ggml_tensor * moe_out = - build_moe_ffn(cur, - layer.ffn_gate_inp, - layer.ffn_up_exps, - layer.ffn_gate_exps, - layer.ffn_down_exps, - nullptr, - n_expert, n_expert_used, - LLM_FFN_SILU, true, - hparams.expert_weights_scale, - LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, - nullptr, layer.ffn_gate_up_exps, - layer.ffn_up_exps_s, - layer.ffn_gate_exps_s, - layer.ffn_down_exps_s); - cb(moe_out, "mtp_ffn_moe_out", il); - - if (layer.ffn_up_shexp != nullptr) { - ggml_tensor * ffn_shexp = - build_ffn(cur, - layer.ffn_up_shexp, nullptr, layer.ffn_up_shexp_s, - layer.ffn_gate_shexp, nullptr, layer.ffn_gate_shexp_s, - layer.ffn_down_shexp, nullptr, layer.ffn_down_shexp_s, - nullptr, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(ffn_shexp, "mtp_ffn_shexp", il); - - ggml_tensor * shared_gate = build_lora_mm(layer.ffn_gate_inp_shexp, cur); - shared_gate = ggml_sigmoid(ctx0, shared_gate); - cb(shared_gate, "mtp_shared_expert_gate_sigmoid", il); - - ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); - cb(ffn_shexp, "mtp_ffn_shexp_gated", il); - - cur = ggml_add(ctx0, moe_out, ffn_shexp); - } else { - cur = moe_out; - } - cb(cur, "mtp_ffn_out", il); - - cur = ggml_add(ctx0, cur, ffn_residual); - cb(cur, "mtp_post_ffn", il); - - // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. - cb(cur, "h_pre_norm", -1); - res->t_h_pre_norm = cur; - - ggml_tensor * head_norm_w = layer.nextn.shared_head_norm - ? layer.nextn.shared_head_norm - : model.output_norm; - GGML_ASSERT(head_norm_w && "QWEN35MOE_MTP: missing both nextn.shared_head_norm and output_norm"); - cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); - cb(cur, "mtp_shared_head_norm", -1); - - ggml_tensor * head_w = layer.nextn.shared_head_head ? 
layer.nextn.shared_head_head : model.output; - GGML_ASSERT(head_w && "QWEN35MOE_MTP: missing LM head (nextn.shared_head_head or model.output)"); - cur = build_lora_mm(head_w, cur); - cb(cur, "result_output", -1); - - res->t_logits = cur; - ggml_build_forward_expand(gf, cur); -} diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index 43d9c7a1e3c..b11cdaa6edd 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -18,7 +18,6 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; // Mark recurrent layers (linear attention layers). MTP layers are dense // attention-only and must be flagged non-recurrent. @@ -39,9 +38,14 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { } } -void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) { +void llama_model_qwen35moe::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + const bool mtp_only = (hparams.nextn_predict_layers > 0) && + (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); // output @@ -53,70 +57,105 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); } - const int64_t n_ff_exp = hparams.n_ff_exp ? 
hparams.n_ff_exp : n_ff / n_expert_used; + auto load_block_trunk = [&](int il, int flags) { + auto & layer = layers[il]; - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recurrent(il)) { // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { 
n_embd_head_k * n_head, n_embd }, flags); // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags); } else { // Linear attention (gated delta net) specific tensors // Create tensors with calculated dimensions - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, flags); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags); + 
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags); } - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + // Routed experts + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, flags); + create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, flags); // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, flags); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, flags); + }; + + auto load_block_mtp = [&](int il) { + auto & layer = layers[il]; + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; const int64_t n_ff_shexp = hparams.n_ff_shexp ? 
hparams.n_ff_shexp : n_ff; - layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); - - // NextN/MTP tensors (preserved but unused) - only bound on MTP layers - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); - } + // MTP block looks like a full-attention Qwen3.5 decoder block with MoE FFN. 
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0); + + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0); + + // Routed experts + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, 0); + create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, 0); + + // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, 0); + + // NextN-specific tensors that define the MTP block. 
+ layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < (int) n_main; ++i) { + load_block_trunk(i, trunk_flags); + } + for (int i = (int) n_main; i < n_layer; ++i) { + load_block_mtp(i); } } std::unique_ptr llama_model_qwen35moe::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique(*this, params); + } return std::make_unique(*this, params); } @@ -547,3 +586,178 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_ffn(ggml_tensor * cur, c return cur; } + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 MoE +llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35MOE MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block 
missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + GGML_ASSERT(layer.ffn_gate_inp && "MTP block missing ffn_gate_inp"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = 
ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + // MoE FFN — routed experts plus gated shared expert (mirrors qwen35moe). 
+ ggml_tensor * moe_out = + build_moe_ffn(cur, + layer.ffn_gate_inp, + layer.ffn_up_exps, + layer.ffn_gate_exps, + layer.ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, + nullptr, layer.ffn_gate_up_exps, + layer.ffn_up_exps_s, + layer.ffn_gate_exps_s, + layer.ffn_down_exps_s); + cb(moe_out, "mtp_ffn_moe_out", il); + + if (layer.ffn_up_shexp != nullptr) { + ggml_tensor * ffn_shexp = + build_ffn(cur, + layer.ffn_up_shexp, nullptr, layer.ffn_up_shexp_s, + layer.ffn_gate_shexp, nullptr, layer.ffn_gate_shexp_s, + layer.ffn_down_shexp, nullptr, layer.ffn_down_shexp_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "mtp_ffn_shexp", il); + + ggml_tensor * shared_gate = build_lora_mm(layer.ffn_gate_inp_shexp, cur); + shared_gate = ggml_sigmoid(ctx0, shared_gate); + cb(shared_gate, "mtp_shared_expert_gate_sigmoid", il); + + ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); + cb(ffn_shexp, "mtp_ffn_shexp_gated", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + } else { + cur = moe_out; + } + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35MOE MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? 
layer.nextn.shared_head_head : model.output; + GGML_ASSERT(head_w && "QWEN35MOE MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/tests/test-llama-archs.cpp b/tests/test-llama-archs.cpp index fd0d3696d77..03d7c19c78b 100644 --- a/tests/test-llama-archs.cpp +++ b/tests/test-llama-archs.cpp @@ -406,11 +406,7 @@ static bool arch_supported(const llm_arch arch) { if (arch == LLM_ARCH_DEEPSEEK2OCR) { return false; } - if (arch == LLM_ARCH_QWEN35_MTP || arch == LLM_ARCH_QWEN35MOE_MTP) { - return false; // MTP-only arch; requires a sibling trunk model and cannot run standalone. - } - - // FIXME some models are segfaulting with WebGPU: +// FIXME some models are segfaulting with WebGPU: #ifdef GGML_USE_WEBGPU if (arch == LLM_ARCH_QWEN3NEXT || arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE || arch == LLM_ARCH_KIMI_LINEAR) { return false; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 51770c73b74..c0c9bf650a3 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -792,6 +792,14 @@ struct server_context_impl { } auto cparams = common_context_params_to_llama(params_dft); + + const bool spec_mtp = std::find(params_base.speculative.types.begin(), + params_base.speculative.types.end(), + COMMON_SPECULATIVE_TYPE_MTP) != params_base.speculative.types.end(); + if (spec_mtp) { + cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP; + } + ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); @@ -800,36 +808,13 @@ struct server_context_impl { params_base.speculative.draft.ctx_dft = ctx_dft.get(); } else if (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(), COMMON_SPECULATIVE_TYPE_MTP) != params_base.speculative.types.end()) { - // MTP head lives in the 
*target* GGUF — load it as a sibling model - // with override_arch and feed it through the existing ctx_dft slot. - char trunk_arch[64] = {0}; - llama_model_meta_val_str(model_tgt, "general.architecture", trunk_arch, sizeof(trunk_arch)); - - const char * mtp_arch = nullptr; - if (std::string(trunk_arch) == "qwen35") { - mtp_arch = "qwen35_mtp"; - } else if (std::string(trunk_arch) == "qwen35moe") { - mtp_arch = "qwen35moe_mtp"; - } else { - SRV_ERR("MTP not supported for trunk architecture '%s'\n", trunk_arch); - return false; - } - - SRV_INF("loading MTP head from '%s' (override_arch=%s)\n", - params_base.model.path.c_str(), mtp_arch); - - auto mparams_mtp = common_model_params_to_llama(params_base); - mparams_mtp.override_arch = mtp_arch; - - model_dft.reset(llama_model_load_from_file(params_base.model.path.c_str(), mparams_mtp)); - if (model_dft == nullptr) { - SRV_ERR("failed to load MTP head from '%s'\n", params_base.model.path.c_str()); - return false; - } + SRV_INF("creating MTP draft context against the target model '%s'\n", + params_base.model.path.c_str()); auto cparams_mtp = common_context_params_to_llama(params_base); + cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP; - ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams_mtp)); + ctx_dft.reset(llama_init_from_model(model_tgt, cparams_mtp)); if (ctx_dft == nullptr) { SRV_ERR("%s", "failed to create MTP context\n"); return false; From 7ea1289a40975f82a739fb2f3b07428fa34aec19 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 13 May 2026 14:43:12 +0800 Subject: [PATCH 06/28] mtp -> draft-mtp --- common/arg.cpp | 8 ++++---- common/common.h | 2 +- common/speculative.cpp | 20 ++++++++++---------- tools/server/server-context.cpp | 6 +++--- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 747a9e81990..8b8eb7c12bd 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -445,11 +445,11 @@ static bool parse_bool_value(const std::string & value) { // void 
common_params_handle_models(common_params & params, llama_example curr_ex) { - const bool spec_type_mtp = std::find(params.speculative.types.begin(), + const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(), params.speculative.types.end(), - COMMON_SPECULATIVE_TYPE_MTP) != params.speculative.types.end(); + COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end(); - auto res = common_params_handle_model(params.model, params.hf_token, params.offline, spec_type_mtp); + auto res = common_params_handle_model(params.model, params.hf_token, params.offline, spec_type_draft_mtp); if (params.no_mmproj) { params.mmproj = {}; } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) { @@ -465,7 +465,7 @@ void common_params_handle_models(common_params & params, llama_example curr_ex) } // when --spec-type mtp is set and no draft model was provided explicitly, // fall back to the MTP head discovered alongside the -hf model - if (spec_type_mtp && res.found_mtp && + if (spec_type_draft_mtp && res.found_mtp && params.speculative.draft.mparams.path.empty() && params.speculative.draft.mparams.hf_repo.empty() && params.speculative.draft.mparams.url.empty()) { diff --git a/common/common.h b/common/common.h index 37dab50ba5c..b813cec39d6 100644 --- a/common/common.h +++ b/common/common.h @@ -159,7 +159,7 @@ enum common_speculative_type { COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, // standalone draft model speculative decoding COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, // Eagle3 speculative decoding - COMMON_SPECULATIVE_TYPE_MTP, // multi-token prediction head loaded from the target GGUF + COMMON_SPECULATIVE_TYPE_DRAFT_MTP, // multi-token prediction head loaded from the target GGUF COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding based on n-grams COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only 
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values diff --git a/common/speculative.cpp b/common/speculative.cpp index f064009a225..fd721f2df04 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -24,7 +24,7 @@ const std::map common_speculative_type_fro {"none", COMMON_SPECULATIVE_TYPE_NONE}, {"draft-simple", COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE}, {"draft-eagle3", COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3}, - {"mtp", COMMON_SPECULATIVE_TYPE_MTP}, + {"draft-mtp", COMMON_SPECULATIVE_TYPE_DRAFT_MTP}, {"ngram-simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE}, {"ngram-map-k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K}, {"ngram-map-k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V}, @@ -366,7 +366,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl { } }; -struct common_speculative_state_mtp : public common_speculative_impl { +struct common_speculative_state_draft_mtp : public common_speculative_impl { common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft) llama_batch batch; @@ -383,8 +383,8 @@ struct common_speculative_state_mtp : public common_speculative_impl { std::vector i_batch_beg; std::vector i_batch_end; - common_speculative_state_mtp(const common_params_speculative & params, uint32_t n_seq) - : common_speculative_impl(COMMON_SPECULATIVE_TYPE_MTP, n_seq) + common_speculative_state_draft_mtp(const common_params_speculative & params, uint32_t n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq) , params(params.draft) { auto * ctx_tgt = this->params.ctx_tgt; @@ -417,7 +417,7 @@ struct common_speculative_state_mtp : public common_speculative_impl { i_batch_end.assign(n_seq, -1); } - ~common_speculative_state_mtp() override { + ~common_speculative_state_draft_mtp() override { if (batch.token != nullptr) { free(batch.token); batch.token = nullptr; @@ -1106,7 +1106,7 @@ std::string 
common_speculative_type_to_str(common_speculative_type type) { case COMMON_SPECULATIVE_TYPE_NONE: return "none"; case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: return "draft-simple"; case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: return "draft-eagle3"; - case COMMON_SPECULATIVE_TYPE_MTP: return "mtp"; + case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: return "draft-mtp"; case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram-simple"; case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram-map-k"; case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram-map-k4v"; @@ -1163,7 +1163,7 @@ common_speculative * common_speculative_init(common_params_speculative & params, bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE)); bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3 - bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_MTP)) && params.draft.ctx_dft != nullptr; + bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr; bool has_ngram_cache = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_CACHE)); bool has_ngram_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE)); @@ -1210,7 +1210,7 @@ common_speculative * common_speculative_init(common_params_speculative & params, configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params)); } if (has_mtp) { - configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params)); + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params)); } } @@ -1229,8 +1229,8 @@ common_speculative * common_speculative_init(common_params_speculative & params, impls.push_back(std::make_unique(config.params, n_seq)); break; } - case COMMON_SPECULATIVE_TYPE_MTP: { - impls.push_back(std::make_unique(config.params, n_seq)); + case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: { + impls.push_back(std::make_unique(config.params, n_seq)); break; } 
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index c0c9bf650a3..5b95ac05a32 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -795,7 +795,7 @@ struct server_context_impl { const bool spec_mtp = std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(), - COMMON_SPECULATIVE_TYPE_MTP) != params_base.speculative.types.end(); + COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end(); if (spec_mtp) { cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP; } @@ -807,7 +807,7 @@ struct server_context_impl { params_base.speculative.draft.ctx_tgt = ctx_tgt; params_base.speculative.draft.ctx_dft = ctx_dft.get(); } else if (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(), - COMMON_SPECULATIVE_TYPE_MTP) != params_base.speculative.types.end()) { + COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end()) { SRV_INF("creating MTP draft context against the target model '%s'\n", params_base.model.path.c_str()); @@ -931,7 +931,7 @@ struct server_context_impl { slot.ctx_dft = ctx_dft.get(); slot.spec = spec.get(); slot.is_mtp_enabled = (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(), - COMMON_SPECULATIVE_TYPE_MTP) != params_base.speculative.types.end()) + COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end()) && (ctx_dft != nullptr); slot.n_ctx = n_ctx_slot; From 9243e50727b6b2a6c6c383dd671806f6b383fdb4 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 13 May 2026 15:00:30 +0800 Subject: [PATCH 07/28] remove unused llama_arch --- include/llama.h | 3 --- src/llama-model.cpp | 10 ---------- 2 files changed, 13 deletions(-) diff --git a/include/llama.h b/include/llama.h index b814e2c58de..470add516c8 100644 --- a/include/llama.h +++ b/include/llama.h @@ -315,9 +315,6 @@ extern "C" { // override key-value pairs of the model meta 
data const struct llama_model_kv_override * kv_overrides; - // override architecture from GGUF (e.g. load the MTP head of a Qwen3.5 GGUF as "qwen35_mtp") - const char * override_arch; - // Keep the booleans together to avoid misalignment during copy-by-value. bool vocab_only; // only load the vocabulary, no weights bool use_mmap; // use mmap if possible diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 5ab183271eb..d7b0996dd9e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -309,15 +309,6 @@ llama_model * llama_model_create(llama_model_loader & ml, const llama_model_para if (arch == LLM_ARCH_UNKNOWN) { throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); } - if (params.override_arch != nullptr && params.override_arch[0] != '\0') { - const llm_arch override = llm_arch_from_string(params.override_arch); - if (override == LLM_ARCH_UNKNOWN) { - throw std::runtime_error(std::string("unknown override architecture: '") + params.override_arch + "'"); - } - LLAMA_LOG_INFO("%s: overriding architecture %s -> %s\n", - __func__, llm_arch_name(arch), llm_arch_name(override)); - arch = override; - } return llama_model_create(arch, params); } @@ -2131,7 +2122,6 @@ llama_model_params llama_model_default_params() { /*.progress_callback =*/ nullptr, /*.progress_callback_user_data =*/ nullptr, /*.kv_overrides =*/ nullptr, - /*.override_arch =*/ nullptr, /*.vocab_only =*/ false, /*.use_mmap =*/ true, /*.use_direct_io =*/ false, From 23ae80ae48e5a6e5a920107797f170f8b9c4ae86 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 13 May 2026 15:09:00 +0800 Subject: [PATCH 08/28] add need_embd in speculative --- common/speculative.cpp | 45 +++++++++++++++++++++++++++++++++ common/speculative.h | 3 +++ tools/server/server-context.cpp | 16 +----------- 3 files changed, 49 insertions(+), 15 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index fd721f2df04..cddfcfd6bfc 100644 --- a/common/speculative.cpp +++ 
b/common/speculative.cpp @@ -145,6 +145,9 @@ struct common_speculative_impl { virtual void draft(common_speculative_draft_params_vec & dparams) = 0; virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0; + + // true if this implementation requires the target context to extract embeddings + virtual bool need_embd() const = 0; }; struct common_speculative_impl_draft_simple : public common_speculative_impl { @@ -340,6 +343,10 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl { void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop } + + bool need_embd() const override { + return false; + } }; struct common_speculative_impl_draft_eagle3 : public common_speculative_impl { @@ -364,6 +371,10 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl { void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop } + + bool need_embd() const override { + return false; + } }; struct common_speculative_state_draft_mtp : public common_speculative_impl { @@ -648,6 +659,10 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { } + + bool need_embd() const override { + return true; + } }; // state of self-speculation (simple implementation, not ngram-map) @@ -689,6 +704,10 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl { void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop } + + bool need_embd() const override { + return false; + } }; struct common_speculative_impl_ngram_map_k : public common_speculative_impl { @@ -737,6 +756,10 @@ struct common_speculative_impl_ngram_map_k : public common_speculative_impl { common_ngram_map_accept(config[seq_id], n_accepted); } + + bool need_embd() const override { + return false; + } }; struct common_speculative_impl_ngram_mod : public common_speculative_impl { @@ -905,6 +928,10 @@ 
struct common_speculative_impl_ngram_mod : public common_speculative_impl { } } } + + bool need_embd() const override { + return false; + } }; struct common_speculative_impl_ngram_cache : public common_speculative_impl { @@ -1038,6 +1065,10 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl { void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop } + + bool need_embd() const override { + return false; + } }; struct common_speculative { @@ -1333,6 +1364,20 @@ bool common_speculative_process(common_speculative * spec, const llama_batch & b return result; } +bool common_speculative_need_embd(common_speculative * spec) { + if (spec == nullptr) { + return false; + } + + for (auto & impl : spec->impls) { + if (impl->need_embd()) { + return true; + } + } + + return false; +} + void common_speculative_draft(common_speculative * spec) { if (spec == nullptr) { return; diff --git a/common/speculative.h b/common/speculative.h index 51f0b059fa4..614db9b1b50 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -53,6 +53,9 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co // process the batch and update the internal state of the speculative context bool common_speculative_process(common_speculative * spec, const llama_batch & batch); +// true if any implementation requires target embeddings to be extracted +bool common_speculative_need_embd(common_speculative * spec); + // generate drafts for the sequences specified with `common_speculative_get_draft_params` void common_speculative_draft(common_speculative * spec); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 5b95ac05a32..649a5bf7551 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -57,11 +57,6 @@ struct server_slot { llama_context * ctx_tgt = nullptr; llama_context * ctx_dft = nullptr; - // True when this slot's speculative impl is MTP (ctx_dft is the MTP 
head). - // MTP needs every prefill position to carry logits=1 so the streaming - // hook in common_speculative_state_mtp::process() can read t_h_pre_norm. - bool is_mtp_enabled = false; - // multimodal mtmd_context * mctx = nullptr; @@ -243,15 +238,9 @@ struct server_slot { (ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size()); } - bool is_mtp() const { return is_mtp_enabled; } - - // The trunk needs to emit logits at every prefill position when either: - // - the task asked for embeddings, or - // - MTP is enabled for this slot (the streaming hook in process() reads - // h_pre_norm at every prompt position). bool need_embd() const { GGML_ASSERT(task); - return task->need_embd() || is_mtp(); + return task->need_embd() || (spec && common_speculative_need_embd(spec)); } // if the context does not have a memory module then all embeddings have to be computed within a single ubatch @@ -930,9 +919,6 @@ struct server_context_impl { slot.ctx_tgt = ctx_tgt; slot.ctx_dft = ctx_dft.get(); slot.spec = spec.get(); - slot.is_mtp_enabled = (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(), - COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end()) - && (ctx_dft != nullptr); slot.n_ctx = n_ctx_slot; slot.mctx = mctx; From a5b3e988d5aba8c8d83a9759f28e342165a70cab Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sun, 26 Apr 2026 00:42:04 +0800 Subject: [PATCH 09/28] llama: allow partial seq_rm for GDN models for speculative decoding Currently, speculative checkpointing has to restart from a checkpoint whenever some draft tokens are not accepted, which wastes work re-running the target. This PR adds the ability to roll back up to `draft_max` tokens by storing the GDN intermediates.
--- common/common.cpp | 16 +++ common/common.h | 7 +- ggml/include/ggml.h | 7 ++ ggml/src/ggml-backend-meta.cpp | 4 +- ggml/src/ggml-cpu/ggml-cpu.c | 4 +- ggml/src/ggml-cpu/ops.cpp | 43 ++++++-- ggml/src/ggml-cuda/gated_delta_net.cu | 88 +++++++++++----- ggml/src/ggml.c | 12 ++- include/llama.h | 2 + src/llama-arch.cpp | 10 ++ src/llama-arch.h | 1 + src/llama-context.cpp | 12 +++ src/llama-cparams.h | 1 + src/llama-graph.cpp | 3 +- src/llama-memory-hybrid-iswa.cpp | 2 + src/llama-memory-hybrid-iswa.h | 1 + src/llama-memory-hybrid.cpp | 2 + src/llama-memory-hybrid.h | 1 + src/llama-memory-recurrent.cpp | 47 +++++++-- src/llama-memory-recurrent.h | 8 ++ src/llama-model.cpp | 3 + src/models/delta-net-base.cpp | 144 +++++++++++++++++++++++++- src/models/models.h | 28 ++++- src/models/qwen35.cpp | 46 +------- src/models/qwen35moe.cpp | 46 +------- src/models/qwen3next.cpp | 44 +------- tests/test-backend-ops.cpp | 21 +++- tools/server/server-context.cpp | 7 +- 28 files changed, 422 insertions(+), 188 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index b701edddb3f..9a653f0f5fa 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1435,6 +1435,11 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) { goto done; } + if (llama_n_rs_seq(ctx) > 0) { + res = COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED; + goto done; + } + // try to remove the last tokens if (!llama_memory_seq_rm(mem, 0, 1, -1)) { LOG_TRC("%s: the context does not support partial sequence removal\n", __func__); @@ -1505,6 +1510,17 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.n_ctx = params.n_ctx; cparams.n_seq_max = params.n_parallel; + { + // TODO: add for MTP + bool has_spec = params.speculative.has_dft(); + for (auto t : params.speculative.types) { + if (t != COMMON_SPECULATIVE_TYPE_NONE) { + has_spec = true; + break; + } + } + cparams.n_rs_seq = has_spec ? 
(uint32_t) params.speculative.draft.n_max : 0u; + } cparams.n_batch = params.n_batch; cparams.n_ubatch = params.n_ubatch; cparams.n_threads = params.cpuparams.n_threads; diff --git a/common/common.h b/common/common.h index b813cec39d6..30b8933264d 100644 --- a/common/common.h +++ b/common/common.h @@ -885,9 +885,10 @@ std::string common_get_model_endpoint(); // enum common_context_seq_rm_type { - COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. no memory module) - COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences - COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only + COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. no memory module) + COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences + COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only + COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED = 3, // can seq_rm partial sequences, bounded by n_rs_seq }; // check if the llama_context can remove sequences diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 3357a0d9985..be1b6be4cb2 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2541,6 +2541,13 @@ extern "C" { // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST] // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306 + // + // state is a 3D tensor of shape (S_v*S_v*H, K, n_seqs). K is the snapshot slot count: + // K == 1 → output carries the final state only. + // K > 1 → output carries K snapshot slots; the kernel writes the last min(n_tokens, K) + // per-token snapshots into the trailing slots (earlier slots are left untouched + // when n_tokens < K). + // Only slot 0 (state[:, 0, :]) is read as the initial state; the rest is shape signal. 
GGML_API struct ggml_tensor * ggml_gated_delta_net( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index c0ffd9a048b..6fd401caa6f 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -753,7 +753,9 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1); - GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2); + // state shape is (S_v*S_v*H, K, n_seqs); the heads dim is nested inside axis 0, + // so a head-aligned split on the input cache reshapes to axis 0 here (not axis 2). + GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0); return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; }; diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 7b05edf6b75..cd5c61a8187 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2943,7 +2943,9 @@ struct ggml_cplan ggml_graph_plan( case GGML_OP_GATED_DELTA_NET: { const int64_t S_v = node->src[2]->ne[0]; - cur = S_v * sizeof(float) * n_tasks; + const int64_t K = node->src[5]->ne[1]; // state is (D, K, n_seqs) + const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); + cur = per_thread * sizeof(float) * n_tasks; } break; case GGML_OP_COUNT: { diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 6bc8dc150ce..7485ba4fc86 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10513,19 +10513,30 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const bool kda = (neg0 == S_v); - // scratch layout per thread: [delta(S_v)] - const int64_t scratch_per_thread = S_v; + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. 
+ const int64_t K = src_state->ne[1]; + GGML_ASSERT(K >= 1); + // per-seq stride in floats (slot 0 of seq s lives at state + s * seq_stride) + const int64_t state_seq_stride = src_state->nb[2] / sizeof(float); + + const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); const int ith = params->ith; - float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32; + float * delta = (float *)params->wdata + ith * per_thread + CACHE_LINE_SIZE_F32; + float * state_work = K > 1 ? (delta + S_v) : nullptr; // output layout: [attn_scores | new_states] - // attn_scores: S_v * H * n_tokens * n_seqs floats - // new_states: S_v * S_v * H * n_seqs floats - const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + // attn_scores: S_v * H * n_tokens * n_seqs floats + // new_states: S_v * S_v * H * n_seqs * K floats (K snapshot slots; last min(n_tokens, K)) + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + const int64_t state_size_per_snap = S_v * S_v * H * n_seqs; float * attn_out_base = (float *)dst->data; float * state_out_base = (float *)dst->data + attn_score_elems; + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). + const int64_t shift = n_tokens - K; + const float * state_in_base = (const float *)src_state->data; //const int64_t rq1 = nev1 / neq1; @@ -10545,10 +10556,15 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const int64_t iq3 = iv3 / rq3; const int64_t ik3 = iv3 / rk3; - float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v; + // For K=1, write directly to the single output slot to avoid an extra memcpy at the end. + // For K>1, work in scratch and copy out per-token when the slot is in range. + float * s_out = (K > 1) + ? 
state_work + : state_out_base + (iv3 * H + iv1) * S_v * S_v; - // copy input state into output buffer and operate in-place - const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v; + // copy input state into the working buffer and operate in-place + // state layout (D, K, n_seqs): slot 0 of seq iv3 starts at iv3 * state_seq_stride. + const float * s_in = state_in_base + iv3 * state_seq_stride + iv1 * S_v * S_v; memcpy(s_out, s_in, S_v * S_v * sizeof(float)); // attn output pointer for first token of this (head, seq) @@ -10598,6 +10614,15 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( } attn_data += S_v * H; // advance to next token + + if (K > 1) { + const int64_t target_slot = t - shift; + if (target_slot >= 0 && target_slot < K) { + float * curr_state_o = state_out_base + target_slot * state_size_per_snap + + (iv3 * H + iv1) * S_v * S_v; + memcpy(curr_state_o, s_out, S_v * S_v * sizeof(float)); + } + } } } } diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 6b44bec7317..ef4bb29c8ad 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -1,6 +1,6 @@ #include "gated_delta_net.cuh" -template +template __global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? 
ggml_cuda_get_physical_warp_size() : S_v) * 4, 2) gated_delta_net_cuda(const float * q, const float * k, @@ -23,7 +23,8 @@ gated_delta_net_cuda(const float * q, int64_t sb3, const uint3 neqk1_magic, const uint3 rq3_magic, - float scale) { + float scale, + int K) { const uint32_t h_idx = blockIdx.x; const uint32_t sequence = blockIdx.y; // each warp owns one column, using warp-level primitives to reduce across rows @@ -37,9 +38,13 @@ gated_delta_net_cuda(const float * q, float * attn_data = dst; float * state = dst + attn_score_elems; - const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v; - state += state_offset; - curr_state += state_offset + col * S_v; + // input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v. + // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. + const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; + const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output + state += state_out_offset; + curr_state += state_in_offset + col * S_v; attn_data += (sequence * n_tokens * H + h_idx) * S_v; constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v; @@ -54,6 +59,10 @@ gated_delta_net_cuda(const float * q, s_shard[r] = curr_state[i]; } + // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots + // are written; earlier slots are left untouched (caller-owned). 
+ const int shift = (int) n_tokens - K; + for (int t = 0; t < n_tokens; t++) { const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; @@ -135,17 +144,30 @@ gated_delta_net_cuda(const float * q, } attn_data += S_v * H; + + if constexpr (keep_intermediates_t) { + const int target_slot = t - shift; + if (target_slot >= 0 && target_slot < K) { + float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + curr_state[col * S_v + i] = s_shard[r]; + } + } + } } - // Write state back to global memory (transposed layout) + if constexpr (!keep_intermediates_t) { #pragma unroll - for (int r = 0; r < rows_per_lane; r++) { - const int i = r * warp_size + lane; - state[col * S_v + i] = s_shard[r]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + state[col * S_v + i] = s_shard[r]; + } } } -template +template static void launch_gated_delta_net( const float * q_d, const float * k_d, const float * v_d, const float * g_d, const float * b_d, const float * s_d, @@ -155,7 +177,7 @@ static void launch_gated_delta_net( int64_t sv1, int64_t sv2, int64_t sv3, int64_t sb1, int64_t sb2, int64_t sb3, int64_t neqk1, int64_t rq3, - float scale, cudaStream_t stream) { + float scale, int K, cudaStream_t stream) { //TODO: Add chunked kernel for even faster pre-fill const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; const int num_warps = 4; @@ -169,29 +191,29 @@ static void launch_gated_delta_net( switch (S_v) { case 16: - gated_delta_net_cuda<16, KDA><<>>( + gated_delta_net_cuda<16, KDA, keep_intermediates_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 32: - 
gated_delta_net_cuda<32, KDA><<>>( + gated_delta_net_cuda<32, KDA, keep_intermediates_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 64: { - gated_delta_net_cuda<64, KDA><<>>( + gated_delta_net_cuda<64, KDA, keep_intermediates_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; } case 128: { - gated_delta_net_cuda<128, KDA><<>>( + gated_delta_net_cuda<128, KDA, keep_intermediates_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; } default: @@ -261,13 +283,29 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * cudaStream_t stream = ctx.stream(); + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. 
+ const int K = (int) src_state->ne[1]; + const bool keep_intermediates = K > 1; + if (kda) { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_intermediates) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } else { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_intermediates) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 191cf2fa106..476c3079795 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6210,11 +6210,13 @@ struct ggml_tensor * ggml_gated_delta_net( GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); GGML_ASSERT(beta->ne[0] == 1); - GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs); - - // concat output and new_state into a single tensor - // output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs - const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 }; + // state is a 3D tensor (S_v*S_v*H, K, n_seqs). K is the snapshot slot count. 
+ GGML_ASSERT(state->ne[0] == S_v * S_v * H); + GGML_ASSERT(state->ne[2] == n_seqs); + GGML_ASSERT(state->ne[3] == 1); + const int64_t K = state->ne[1]; + const int64_t state_rows = K * S_v * n_seqs; + const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); result->op = GGML_OP_GATED_DELTA_NET; diff --git a/include/llama.h b/include/llama.h index 470add516c8..0f711ab29e3 100644 --- a/include/llama.h +++ b/include/llama.h @@ -338,6 +338,7 @@ extern "C" { uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode uint32_t n_ubatch; // physical maximum batch size uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models) + uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL] int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing @@ -536,6 +537,7 @@ extern "C" { LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); + LLAMA_API uint32_t llama_n_rs_seq (const struct llama_context * ctx); DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead"); DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead"); diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index ab4334da79b..c95f341b07d 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -878,6 +878,16 @@ bool llm_arch_is_diffusion(const llm_arch & arch) { } } +bool llm_arch_supports_recurrent_partial_rollback(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: + return true; + default: + return false; + } +} + bool 
llm_arch_supports_sm_tensor(const llm_arch & arch) { switch (arch) { case LLM_ARCH_GROK: diff --git a/src/llama-arch.h b/src/llama-arch.h index e37d548c98e..e3765dba3da 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -637,3 +637,4 @@ bool llm_arch_is_recurrent (const llm_arch & arch); bool llm_arch_is_hybrid (const llm_arch & arch); bool llm_arch_is_diffusion (const llm_arch & arch); bool llm_arch_supports_sm_tensor(const llm_arch & arch); +bool llm_arch_supports_recurrent_partial_rollback(const llm_arch & arch); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 6ecbe1b6083..4a000f59f1c 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -51,6 +51,13 @@ llama_context::llama_context( throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ)); } + cparams.n_rs_seq = params.n_rs_seq; + if (cparams.n_rs_seq > 0 && !llm_arch_supports_recurrent_partial_rollback(model.arch)) { + LLAMA_LOG_DEBUG("%s: n_rs_seq=%u requested but model arch does not support recurrent partial rollback; clamping to 0\n", + __func__, cparams.n_rs_seq); + cparams.n_rs_seq = 0; + } + cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? 
params.yarn_ext_factor : hparams.yarn_ext_factor; @@ -3291,6 +3298,7 @@ llama_context_params llama_context_default_params() { /*.n_batch =*/ 2048, /*.n_ubatch =*/ 512, /*.n_seq_max =*/ 1, + /*.n_rs_seq =*/ 0, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, /*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT, @@ -3445,6 +3453,10 @@ uint32_t llama_n_seq_max(const llama_context * ctx) { return ctx->n_seq_max(); } +uint32_t llama_n_rs_seq(const llama_context * ctx) { + return ctx->get_cparams().n_rs_seq; +} + const llama_model * llama_get_model(const llama_context * ctx) { return &ctx->get_model(); } diff --git a/src/llama-cparams.h b/src/llama-cparams.h index cbf74eba63e..5898a1c38d5 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -12,6 +12,7 @@ struct llama_cparams { uint32_t n_batch; uint32_t n_ubatch; uint32_t n_seq_max; + uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index fe155c92dea..858c297dd76 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -2528,7 +2528,8 @@ ggml_tensor * llm_graph_context::build_rs( int32_t rs_zero, const llm_graph_get_rows_fn & get_state_rows) const { - ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size); + GGML_UNUSED(rs_size); + ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, s->ne[1]); // Clear a single state which will then be copied to the other cleared states. // Note that this is a no-op when the view is zero-sized. 
diff --git a/src/llama-memory-hybrid-iswa.cpp b/src/llama-memory-hybrid-iswa.cpp index 10e6b459797..a59561ea54d 100644 --- a/src/llama-memory-hybrid-iswa.cpp +++ b/src/llama-memory-hybrid-iswa.cpp @@ -24,6 +24,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ @@ -54,6 +55,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( offload, rs_size, n_seq_max, + n_rs_seq, filter_recr == nullptr ? [&](int32_t il) { return hparams.is_recurrent(il); } : filter_recr diff --git a/src/llama-memory-hybrid-iswa.h b/src/llama-memory-hybrid-iswa.h index 807c8aac96c..c9d3f9f57c5 100644 --- a/src/llama-memory-hybrid-iswa.h +++ b/src/llama-memory-hybrid-iswa.h @@ -34,6 +34,7 @@ class llama_memory_hybrid_iswa : public llama_memory_i { uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index 4ce1af592c1..fd305cab79c 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -24,6 +24,7 @@ llama_memory_hybrid::llama_memory_hybrid( uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ @@ -54,6 +55,7 @@ llama_memory_hybrid::llama_memory_hybrid( offload, rs_size, n_seq_max, + n_rs_seq, filter_recr == nullptr ? 
[&](int32_t il) { return hparams.is_recurrent(il); } : filter_recr diff --git a/src/llama-memory-hybrid.h b/src/llama-memory-hybrid.h index 558cafdf984..484eafb7499 100644 --- a/src/llama-memory-hybrid.h +++ b/src/llama-memory-hybrid.h @@ -34,6 +34,7 @@ class llama_memory_hybrid : public llama_memory_i { uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index c07f1d969cb..1913e9414a0 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -24,6 +24,7 @@ llama_memory_recurrent::llama_memory_recurrent( bool offload, uint32_t mem_size, uint32_t n_seq_max, + uint32_t n_rs_seq, const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) { const int32_t n_layer = hparams.n_layer; @@ -31,6 +32,9 @@ llama_memory_recurrent::llama_memory_recurrent( size = mem_size; used = 0; + this->n_rs_seq = n_rs_seq; + rs_idx.assign(n_seq_max, 0); + cells.clear(); cells.resize(mem_size); @@ -92,8 +96,9 @@ llama_memory_recurrent::llama_memory_recurrent( throw std::runtime_error("failed to create ggml context for rs cache"); } - ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), mem_size); - ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), mem_size); + const uint32_t n_rows = mem_size * (1 + n_rs_seq); + ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), n_rows); + ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), n_rows); ggml_format_name(r, "cache_r_l%d", i); ggml_format_name(s, "cache_s_l%d", i); r_l[i] = r; @@ -141,7 +146,6 @@ void llama_memory_recurrent::clear(bool data) { } bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - //printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1); uint32_t new_head = size; if (p0 < 0) { @@ 
-161,10 +165,16 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos if (0 <= seq_id) { int32_t & tail_id = cells[seq_id].tail; if (tail_id >= 0) { - const auto & cell = cells[tail_id]; - // partial intersection is invalid if it includes the final pos + auto & cell = cells[tail_id]; + + // partial rollback via per-token snapshot index (bounded by n_rs_seq) if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { - //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1); + const llama_pos rollback = cell.pos - (p0 - 1); + if (rollback >= 1 && rollback <= (llama_pos) n_rs_seq) { + set_rs_idx(seq_id, (uint32_t) rollback); + cell.pos = p0 - 1; + return true; + } return false; } // invalidate tails which will be cleared @@ -368,6 +378,13 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } +void llama_memory_recurrent::set_rs_idx(llama_seq_id seq_id, uint32_t idx) { + if (seq_id < 0 || (size_t) seq_id >= rs_idx.size()) { + return; + } + rs_idx[seq_id] = (idx > n_rs_seq) ? 
n_rs_seq : idx; +} + std::map llama_memory_recurrent::memory_breakdown() const { std::map ret; for (const auto & [_, buf] : ctxs_bufs) { @@ -1163,5 +1180,21 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const { } int32_t llama_memory_recurrent_context::s_copy(int i) const { - return mem->cells[i + mem->head].src0; + const uint32_t cell_idx = i + mem->head; + const int32_t src0 = mem->cells[cell_idx].src0; + + if (mem->n_rs_seq == 0) { + return src0; + } + + uint32_t idx = 0; + if (!mem->cells[cell_idx].seq_id.empty()) { + const llama_seq_id seq = *mem->cells[cell_idx].seq_id.begin(); + if (seq >= 0 && (size_t) seq < mem->rs_idx.size()) { + idx = mem->rs_idx[seq]; + // reset rollback idx + mem->rs_idx[seq] = 0; + } + } + return (int32_t)(idx * mem->size) + src0; } diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h index 47f01d73912..29c58afc9c2 100644 --- a/src/llama-memory-recurrent.h +++ b/src/llama-memory-recurrent.h @@ -23,6 +23,7 @@ class llama_memory_recurrent : public llama_memory_i { bool offload, uint32_t mem_size, uint32_t n_seq_max, + uint32_t n_rs_seq, const layer_filter_cb & filter); ~llama_memory_recurrent() = default; @@ -69,6 +70,13 @@ class llama_memory_recurrent : public llama_memory_i { uint32_t size = 0; // total number of cells, shared across all sequences uint32_t used = 0; // used cells (i.e. 
at least one seq_id) + // number of recurrent-state snapshots per seq for rollback; tensors are widened to (1 + n_rs_seq) groups + uint32_t n_rs_seq = 0; + // per-seq rollback index + std::vector rs_idx; + + void set_rs_idx(llama_seq_id seq_id, uint32_t idx); + // computed before each graph build uint32_t n = 0; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index d7b0996dd9e..8bf20a716eb 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1961,6 +1961,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.offload_kqv, std::max((uint32_t) 1, cparams.n_seq_max), cparams.n_seq_max, + cparams.n_rs_seq, nullptr); } else if (llm_arch_is_hybrid(arch) && !mtp_on_hybrid_qwen35) { // The main difference between hybrid architectures is the @@ -2002,6 +2003,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* recurrent_type_s */ GGML_TYPE_F32, /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, + /* n_rs_seq */ cparams.n_rs_seq, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, /* filter_attn */ std::move(filter_attn), @@ -2020,6 +2022,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* recurrent_type_v */ GGML_TYPE_F32, /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, + /* n_rs_seq */ cparams.n_rs_seq, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, /* filter_attn */ std::move(filter_attn), diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp index 6bc989c9509..081f490c5a9 100644 --- a/src/models/delta-net-base.cpp +++ b/src/models/delta-net-base.cpp @@ -1,6 +1,7 @@ #include "models.h" #include "llama-impl.h" +#include "llama-memory-recurrent.h" // utility to get one slice from the third dimension // input dim: [x, y, c, b] @@ -397,7 +398,9 @@ std::pair 
llm_build_delta_net_base::build_delta_ne GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s); + // K=1 (final state only): reshape to 3D (S_v*S_v*H_v, 1, n_seqs) for ggml_gated_delta_net. + ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, S_v * S_v * H_v, 1, n_seqs); + ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d); if (n_tokens == 1) { cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il); } else { @@ -443,3 +446,142 @@ std::pair llm_build_delta_net_base::build_delta_ne return build_delta_net_chunking(q, k, v, g, b, s, il); } + +bool llm_build_delta_net_base::keep_intermediates() const { + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + return cparams.n_rs_seq > 0 + && n_seq_tokens > 1 + && (uint32_t) n_seq_tokens <= 1 + cparams.n_rs_seq; +} + +ggml_tensor * llm_build_delta_net_base::build_conv_state( + llm_graph_input_rs * inp, + ggml_tensor * conv_states_all, + ggml_tensor * qkv_mixed, + int64_t conv_kernel_size, + int64_t conv_channels, + int il) { + const auto * mctx_cur = inp->mctx; + const auto kv_head = mctx_cur->get_head(); + const uint32_t mem_size = mctx_cur->get_size(); + const int64_t n_seqs = ubatch.n_seqs; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const bool keep = keep_intermediates(); + + ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); + cb(conv_states, "conv_states", il); + + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + cb(conv_states, "conv_states_reshaped", il); + + qkv_mixed = ggml_transpose(ctx0, qkv_mixed); + cb(qkv_mixed, "qkv_mixed_transposed", il); + + ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); + cb(conv_input, "conv_input", il); + + if (!keep) { + ggml_tensor * last_conv_states = + 
ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], + conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); + cb(last_conv_states, "last_conv_states", il); + + ggml_tensor * state_update_target = + ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], + kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); + cb(state_update_target, "state_update_target", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + } else { + const int64_t row_count = (conv_kernel_size - 1) * conv_channels; + const size_t row_size = row_count * ggml_element_size(conv_states_all); + for (int64_t t = 1; t <= n_seq_tokens; ++t) { + const uint32_t slot = (uint32_t)(n_seq_tokens - t); + ggml_tensor * src = + ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, + conv_input->nb[1], conv_input->nb[2], + t * ggml_element_size(conv_input)); + ggml_tensor * dst = + ggml_view_2d(ctx0, conv_states_all, row_count, n_seqs, + conv_states_all->nb[1], + ((size_t) slot * mem_size + kv_head) * row_size); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); + } + } + + return conv_input; +} + +ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( + llm_graph_input_rs * inp, + ggml_tensor * ssm_states_all, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il) { + const auto * mctx_cur = inp->mctx; + const auto kv_head = mctx_cur->get_head(); + const uint32_t mem_size = mctx_cur->get_size(); + + const int64_t S_v = s->ne[0]; + const int64_t H_v = s->ne[2]; + const int64_t n_seqs = s->ne[3]; + const int64_t n_seq_tokens = q->ne[2]; + + if (!keep_intermediates()) { + auto attn_out = build_delta_net(q, k, v, g, b, s, il); + ggml_tensor * output = attn_out.first; + ggml_tensor * new_state = 
attn_out.second; + cb(output, "attn_output", il); + cb(new_state, "new_state", il); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, new_state, + ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + + return output; + } + + const int64_t D = S_v * S_v * H_v; + const int64_t K = (int64_t) cparams.n_rs_seq; + + ggml_tensor * state_3d = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, D, K, n_seqs); + ggml_tensor * slot_0 = ggml_view_2d(ctx0, state_3d, D, n_seqs, state_3d->nb[2], 0); + ggml_tensor * state_in_2d = ggml_reshape_2d(ctx0, s, D, n_seqs); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, state_in_2d, slot_0)); + + ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, state_3d); + cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il); + + const int64_t attn_score_elems = S_v * H_v * n_seq_tokens * n_seqs; + const int64_t state_size_per_snap = S_v * S_v * H_v * n_seqs; + + ggml_tensor * output = ggml_view_4d(ctx0, gdn_out, + S_v, H_v, n_seq_tokens, n_seqs, + ggml_row_size(gdn_out->type, S_v), + ggml_row_size(gdn_out->type, S_v * H_v), + ggml_row_size(gdn_out->type, S_v * H_v * n_seq_tokens), + 0); + cb(output, "attn_output", il); + + const size_t row_size = hparams.n_embd_s() * ggml_element_size(ssm_states_all); + for (int64_t k_i = 0; k_i < K; ++k_i) { + const uint32_t cache_slot = (uint32_t) (K - 1 - k_i); + ggml_tensor * src = ggml_view_4d(ctx0, gdn_out, + S_v, S_v, H_v, n_seqs, + ggml_row_size(gdn_out->type, S_v), + ggml_row_size(gdn_out->type, S_v * S_v), + ggml_row_size(gdn_out->type, S_v * S_v * H_v), + ggml_row_size(gdn_out->type, attn_score_elems + k_i * state_size_per_snap)); + ggml_tensor * dst = ggml_view_2d(ctx0, ssm_states_all, + hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], + ((size_t) cache_slot * mem_size + kv_head) * row_size); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); + } + + return output; +} diff --git 
a/src/models/models.h b/src/models/models.h index fe95b9b89ad..8bef479754a 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -46,7 +46,7 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * s, int il); - // use the ggml_gated_delta_net fused operator + // use the ggml_gated_delta_net fused operator (K=1; state has shape (D, 1, n_seqs)) std::pair build_delta_net_fused( ggml_tensor * q, ggml_tensor * k, @@ -65,6 +65,32 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * b, ggml_tensor * s, int il); + + // true when speculative rollback is enabled and the batch fits in the rs cache + bool keep_intermediates() const; + + // read conv state from cache, concat with qkv_mixed, write back (single slot or per-token) + // qkv_mixed: (qkv_dim, n_seq_tokens, n_seqs); returns conv_input: (kernel_size + n_seq_tokens - 1, channels, n_seqs) + ggml_tensor * build_conv_state( + llm_graph_input_rs * inp, + ggml_tensor * conv_states_all, + ggml_tensor * qkv_mixed, + int64_t conv_kernel_size, + int64_t conv_channels, + int il); + + // run delta-net attention and write the new recurrent state(s) back to ssm_states_all + // s: (head_v_dim, head_v_dim, num_v_heads, n_seqs); returns output: (head_v_dim, num_v_heads, n_seq_tokens, n_seqs) + ggml_tensor * build_recurrent_attn( + llm_graph_input_rs * inp, + ggml_tensor * ssm_states_all, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); }; struct llm_build_rwkv6_base : public llm_graph_context { diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index e59d7f28856..2b4d5b14cd4 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -348,8 +348,6 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); 
GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -379,41 +377,14 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, 
qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -464,7 +435,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - // note: need explicit repeat only if we are not using the fused GDN + // note: need explicit repeat only if we are not using the fused GDN. if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -475,18 +446,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index b11cdaa6edd..22e3e110765 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -371,8 +371,6 @@ ggml_tensor * 
llama_model_qwen35moe::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -402,41 +400,14 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * 
ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -487,7 +458,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - // note: need explicit repeat only if we are not using the fused GDN + // note: need explicit repeat only if we are not using the fused GDN. if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -498,18 +469,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, 
num_v_heads, n_seq_tokens, n_seqs); diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp index bdc3026c1de..1d873427db5 100644 --- a/src/models/qwen3next.cpp +++ b/src/models/qwen3next.cpp @@ -378,8 +378,6 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -429,41 +427,14 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * 
ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -540,18 +511,7 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 8a561c038a1..76f7cb5a867 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3832,16 +3832,17 @@ struct test_gated_delta_net : public test_case { const int v_repeat; const bool permuted; const bool kda; + 
const int64_t K; // snapshot slot count: 1 = final-only, >1 = last K states std::string vars() override { - return VARS_TO_STR8(type, head_count, head_size, n_seq_tokens, n_seqs, v_repeat, permuted, kda); + return VARS_TO_STR9(type, head_count, head_size, n_seq_tokens, n_seqs, v_repeat, permuted, kda, K); } test_gated_delta_net(ggml_type type = GGML_TYPE_F32, int64_t head_count = 4, int64_t head_size = 16, int64_t n_seq_tokens = 1, int64_t n_seqs = 1, - int v_repeat = 1, bool permuted = false, bool kda = false) + int v_repeat = 1, bool permuted = false, bool kda = false, int64_t K = 1) : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs), - v_repeat(v_repeat), permuted(permuted), kda(kda) {} + v_repeat(v_repeat), permuted(permuted), kda(kda), K(K) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * q; @@ -3863,7 +3864,7 @@ struct test_gated_delta_net : public test_case { const int64_t g_ne0 = kda ? head_size : 1; ggml_tensor * g = ggml_new_tensor_4d(ctx, type, g_ne0, head_count * v_repeat, n_seq_tokens, n_seqs); ggml_tensor * beta = ggml_new_tensor_4d(ctx, type, 1, head_count * v_repeat, n_seq_tokens, n_seqs); - ggml_tensor * state = ggml_new_tensor_2d(ctx, type, head_size * v_repeat * head_size * head_count, n_seqs); + ggml_tensor * state = ggml_new_tensor_3d(ctx, type, head_size * v_repeat * head_size * head_count, K, n_seqs); ggml_set_name(g, "g"); ggml_set_name(beta, "beta"); ggml_set_name(state, "state"); @@ -9034,6 +9035,18 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 33, 1, 1, false, true)); test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 100, 1, 1, false, true)); + // K > 1: output keeps the last min(n_tokens, K) per-token snapshots in the trailing K-token region. 
+ // exact-match cases (K == n_seq_tokens): + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 2, 1, 1, false, false, /*K=*/2)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 4, 1, 1, false, false, /*K=*/4)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, false, false, /*K=*/4)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 128, 4, 1, 1, false, false, /*K=*/4)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, false, true, /*K=*/4)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2, false, true, /*K=*/4)); + // overflow: n_tokens > K — only the last K snapshots kept. + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 8, 1, 1, false, false, /*K=*/3)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 16, 2, 1, false, false, /*K=*/4)); + #if 0 // these tests are disabled to save execution time, sbut they can be handy for debugging test_cases.emplace_back(new test_llama(2, true)); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 649a5bf7551..0f89aa38ea3 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -789,6 +789,7 @@ struct server_context_impl { cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP; } + cparams.n_rs_seq = 0; ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); @@ -2735,9 +2736,11 @@ struct server_context_impl { // checkpoints are created only if: // - the model does not support partial sequence removal // - the model uses SWA (and we are not using `swa_full`) + // - the model supports partial sequence removal but only up to a fixed bound do_checkpoint = do_checkpoint && ( - (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) || - (n_swa > 0)); + ctx_tgt_seq_rm_type == 
COMMON_CONTEXT_SEQ_RM_TYPE_FULL || + ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED || + n_swa > 0); bool has_mtmd = false; From 3aa9ddc6005c6c36d00a0b64c5b89bc075727006 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 14 May 2026 14:32:42 +0800 Subject: [PATCH 10/28] fix pending state --- common/common.cpp | 18 +++--- common/speculative.cpp | 64 +++++++++++++------ .../speculative-simple/speculative-simple.cpp | 3 +- include/llama.h | 2 +- src/llama-context.cpp | 1 + src/llama-memory-recurrent.cpp | 4 +- src/models/delta-net-base.cpp | 5 +- 7 files changed, 64 insertions(+), 33 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 9a653f0f5fa..12acd5bf910 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1436,6 +1436,7 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) { } if (llama_n_rs_seq(ctx) > 0) { + LOG_INF("%s: the context supports bounded partial sequence removal\n", __func__); res = COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED; goto done; } @@ -1511,15 +1512,14 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.n_ctx = params.n_ctx; cparams.n_seq_max = params.n_parallel; { - // TODO: add for MTP - bool has_spec = params.speculative.has_dft(); - for (auto t : params.speculative.types) { - if (t != COMMON_SPECULATIVE_TYPE_NONE) { - has_spec = true; - break; - } - } - cparams.n_rs_seq = has_spec ? (uint32_t) params.speculative.draft.n_max : 0u; + // Since MTP has a low number of draft tokens, enable recurrent checkpointing + // for hybrid attn models + // TODO: figure out how to make it play nicely with other speculative techniques + bool has_mtp = std::any_of(params.speculative.types.begin(), + params.speculative.types.end(), [&](auto t) { + return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP; + }); + cparams.n_rs_seq = has_mtp ? 
(uint32_t) params.speculative.draft.n_max : 0u; } cparams.n_batch = params.n_batch; cparams.n_ubatch = params.n_ubatch; diff --git a/common/speculative.cpp b/common/speculative.cpp index cddfcfd6bfc..0a78bc1136c 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -394,6 +394,16 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { std::vector i_batch_beg; std::vector i_batch_end; + // Hidden rows from the most recent target verification batch, grouped by seq. + // Row 0 corresponds to the sampled token, row N to the Nth accepted draft token. + std::vector> verify_h; + std::vector verify_h_rows; + + // Per-seq draft length from the last draft() call, used in accept() to + // roll back ctx_dft's recurrent state past the AR draft's redundant + // pre-advancement before process() mirrored the verify batch. + std::vector last_n_drafted; + common_speculative_state_draft_mtp(const common_params_speculative & params, uint32_t n_seq) : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq) , params(params.draft) @@ -414,7 +424,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { for (auto & s : smpls) { common_params_sampling sparams; sparams.no_perf = false; - sparams.top_k = 10; + sparams.top_k = 1; sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K }; s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams)); } @@ -426,6 +436,11 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { i_batch_beg.assign(n_seq, -1); i_batch_end.assign(n_seq, -1); + + verify_h.assign(n_seq, {}); + verify_h_rows.assign(n_seq, 0); + + last_n_drafted.assign(n_seq, 0); } ~common_speculative_state_draft_mtp() override { @@ -535,8 +550,17 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { continue; } - const float * h_last = llama_get_embeddings_pre_norm_ith(ctx_tgt, i_batch_end[seq_id]); - std::memcpy(pending_h[seq_id].data(), h_last, row_bytes); + const 
int32_t n_rows = i_batch_end[seq_id] - i_batch_beg[seq_id] + 1; + verify_h_rows[seq_id] = n_rows; + verify_h[seq_id].resize((size_t) n_rows * n_embd); + + for (int32_t i = 0; i < n_rows; ++i) { + const float * h = llama_get_embeddings_pre_norm_ith(ctx_tgt, i_batch_beg[seq_id] + i); + std::memcpy(verify_h[seq_id].data() + (size_t) i * n_embd, h, row_bytes); + } + + std::memcpy(pending_h[seq_id].data(), + verify_h[seq_id].data() + (size_t) (n_rows - 1) * n_embd, row_bytes); } return true; @@ -606,14 +630,6 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { // add drafted token for each sequence const llama_token id = cur_p->data[0].id; - // only collect very high-confidence draft tokens - if (cur_p->data[0].p < params.p_min) { - drafting[seq_id] = false; - n_drafting--; - - continue; - } - common_sampler_accept(smpl, id, true); auto & dp = dparams.at(seq_id); @@ -621,8 +637,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { result.push_back(id); - if ((params.n_max <= (int) result.size()) || - (dp.n_max > 0 && dp.n_max <= (int) result.size())) { + if (params.n_max <= (int) result.size()) { drafting[seq_id] = false; n_drafting--; continue; @@ -646,7 +661,8 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { ++i; } - for (auto & dp : dparams) { + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; if (!dp.drafting) { continue; } @@ -654,10 +670,24 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { if (dp.result->size() < (size_t) params.n_min) { dp.result->clear(); } + + last_n_drafted[seq_id] = (uint16_t) dp.result->size(); } } - void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { + void accept(llama_seq_id seq_id, uint16_t n_accepted) override { + if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) { + return; + } + + const int32_t n_rows = verify_h_rows[seq_id]; + if (n_rows <= 0) 
{ + return; + } + + const int32_t i_h = std::min<int32_t>(n_accepted, n_rows - 1); /* explicit <int32_t>: the arguments are uint16_t and int32_t, so std::min cannot deduce a single T */ + const size_t row_bytes = (size_t) n_embd * sizeof(float); + std::memcpy(pending_h[seq_id].data(), verify_h[seq_id].data() + (size_t) i_h * n_embd, row_bytes); } bool need_embd() const override { @@ -1460,10 +1490,6 @@ void common_speculative_draft(common_speculative * spec) { } void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, uint16_t n_accepted) { - if (n_accepted == 0) { - return; - } - common_speculative_impl * impl = spec->impl_last[seq_id]; GGML_ASSERT(impl); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 5325bcc9e3f..6848de988d0 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -82,7 +82,8 @@ int main(int argc, char ** argv) { } // check if the context supports partial sequence removal - const bool use_ckpt_tgt = (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); + const bool use_ckpt_tgt = (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) + || (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED); const bool use_ckpt_dft = (common_context_can_seq_rm(ctx_dft.get()) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); if (use_ckpt_tgt) { diff --git a/include/llama.h b/include/llama.h index 0f711ab29e3..75095b22d08 100644 --- a/include/llama.h +++ b/include/llama.h @@ -537,7 +537,7 @@ extern "C" { LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); - LLAMA_API uint32_t llama_n_rs_seq (const struct llama_context * ctx); + LLAMA_API uint32_t llama_n_rs_seq (const struct llama_context * ctx); DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train 
instead"); DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead"); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 4a000f59f1c..a481757e388 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -225,6 +225,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false"); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); + LLAMA_LOG_INFO("%s: n_rs_seq = %u\n", __func__, cparams.n_rs_seq); if (cparams.n_ctx_seq < hparams.n_ctx_train) { LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 1913e9414a0..084c5d9ea4f 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -120,8 +120,8 @@ llama_memory_recurrent::llama_memory_recurrent( const size_t memory_size_r = size_r_bytes(); const size_t memory_size_s = size_s_bytes(); - LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, - (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs %2u rs_seq), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, + (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, n_rs_seq, ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f), ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f)); } diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp index 081f490c5a9..852d5bc88ef 100644 --- a/src/models/delta-net-base.cpp +++ b/src/models/delta-net-base.cpp @@ -547,7 +547,10 @@ ggml_tensor * 
llm_build_delta_net_base::build_recurrent_attn( } const int64_t D = S_v * S_v * H_v; - const int64_t K = (int64_t) cparams.n_rs_seq; + // Memory has 1 + n_rs_seq slots (slot 0 = current, slots 1..n_rs_seq = rollback distances). + // The snapshot buffer must match — otherwise the deepest rollback (= n_rs_seq) reads + // uninitialized memory and corrupts the recurrent state. + const int64_t K = (int64_t) cparams.n_rs_seq + 1; ggml_tensor * state_3d = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, D, K, n_seqs); ggml_tensor * slot_0 = ggml_view_2d(ctx0, state_3d, D, n_seqs, state_3d->nb[2], 0); From 2ef737a063c2fd63509d4c12a79a0f0f31802251 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 14 May 2026 08:41:23 +0200 Subject: [PATCH 11/28] vulkan: add GDN partial rollback --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8 ++++- .../vulkan-shaders/gated_delta_net.comp | 29 ++++++++++++++++--- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8c4cf9ef1db..d29a4bab2e2 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1506,6 +1506,7 @@ struct vk_op_gated_delta_net_push_constants { uint32_t sb1, sb2, sb3; uint32_t neq1, rq3; float scale; + uint32_t K; }; struct vk_op_ssm_scan_push_constants { @@ -10767,6 +10768,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const ggml_tensor * src_q = dst->src[0]; const ggml_tensor * src_v = dst->src[2]; const ggml_tensor * src_beta = dst->src[4]; + const ggml_tensor * src_state = dst->src[5]; GGML_ASSERT(dst->buffer != nullptr); @@ -10775,6 +10777,9 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const uint32_t n_tokens = (uint32_t)src_v->ne[2]; const uint32_t n_seqs = (uint32_t)src_v->ne[3]; + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. 
+ const uint32_t K = (uint32_t)src_state->ne[1]; + const uint32_t s_off = S_v * H * n_tokens * n_seqs; vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op); @@ -10808,7 +10813,8 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s sv1, sv2, sv3, sb1, sb2, sb3, neq1, rq3, - scale + scale, + K }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp index 5e9f8308c1d..33c3202dbb7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp @@ -31,6 +31,7 @@ layout(push_constant) uniform Parameters { uint sb1, sb2, sb3; uint neq1, rq3; float scale; + uint K; }; layout(binding = 0) readonly buffer QBuf { FLOAT_TYPE data_q[]; }; @@ -101,13 +102,21 @@ void main() { const uint iq3 = seq_id / rq3; const uint state_size = S_V * S_V; - const uint state_base = (seq_id * H + head_id) * state_size; + // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. + const uint state_in_base = (seq_id * K * H + head_id) * state_size; + // output state layout per slot: same per-(seq,head) offset as the single-slot case. + const uint state_out_base = (seq_id * H + head_id) * state_size; + const uint state_size_per_snap = state_size * H * n_seqs; FLOAT_TYPE s_shard[ROWS_PER_LANE]; [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { - s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]); + s_shard[r] = FLOAT_TYPE(data_state[state_in_base + col * S_V + r * LANES_PER_COLUMN + lane]); } + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). 
+ const int shift = int(n_tokens) - int(K); + uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; for (uint t = 0; t < n_tokens; t++) { @@ -161,9 +170,21 @@ void main() { } attn_off += S_V * H; + + if (K > 1u) { + const int target_slot = int(t) - shift; + if (target_slot >= 0 && target_slot < int(K)) { + const uint slot_base = s_off + uint(target_slot) * state_size_per_snap + state_out_base; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[slot_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + } + } + } } - [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { - data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + if (K == 1u) { + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[s_off + state_out_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + } } } From d7443da5cf0ed9532607e0a2ca7d97806959a038 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 14 May 2026 15:55:53 +0800 Subject: [PATCH 12/28] meta: extend check to axis 1 --- ggml/src/ggml-backend-meta.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 6fd401caa6f..df0f405ed9f 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -755,7 +755,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1); // state shape is (S_v*S_v*H, K, n_seqs); the heads dim is nested inside axis 0, // so a head-aligned split on the input cache reshapes to axis 0 here (not axis 2). 
- GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0); + GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0); return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; }; @@ -2142,4 +2142,3 @@ ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, siz const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context; return backend_ctx->backend_configs[index].backend; } - From 19be81cce9ee43e118333be7e1aafb23bf850f39 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 14 May 2026 10:24:09 +0300 Subject: [PATCH 13/28] metal: add GDN partial rollback Extend the gated delta net kernel to store intermediate states for partial rollback support on the Metal backend. - Add K (snapshot slot count) as a function constant - Read input state from slot 0 of the 3D state tensor - Write intermediate states to different slots during token loop - For K=1, maintain backward-compatible single-slot behavior Ref: https://github.com/ggml-org/llama.cpp/commit/8c05923630110223669f069af2000e9cf10c02bc Assisted-by: llama.cpp:local pi --- ggml/src/ggml-metal/ggml-metal-device.cpp | 5 ++- ggml/src/ggml-metal/ggml-metal.metal | 46 ++++++++++++++++++----- 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index f0147af84c1..e288a27f992 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -590,6 +590,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( const int ne20 = op->src[2]->ne[0]; // S_v const int ne21 = op->src[2]->ne[1]; // H const int ne30 = op->src[3]->ne[0]; // G + // state is src[5], 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. 
+ const int K = op->src[5]->ne[1]; const int nsg = op->src[2]->ne[0]/32; @@ -598,7 +600,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( GGML_ASSERT(ne20 % 32 == 0); snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg); - snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30); + snprintf(name, 256, "%s_ne20=%d_ne30=%d_K=%d", base, ne20, ne30, K); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { @@ -606,6 +608,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0); ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1); + ggml_metal_cv_set_int16(cv, K, FC_GATED_DELTA_NET + 2); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 3882b955847..deb616105ae 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2531,6 +2531,7 @@ kernel void kernel_rwkv_wkv7_f32( constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]]; constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]]; +constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]]; #if 1 template @@ -2552,17 +2553,21 @@ kernel void kernel_gated_delta_net_impl( const uint tx = tpitg.x; const uint ty = tpitg.y; - const uint i23 = tgpig.z; // B - const uint i21 = tgpig.y; // H - const uint i20 = tgpig.x*NSG + ty; + const uint i23 = tgpig.z; // B (n_seqs) + const uint i21 = tgpig.y; // H (head) + const uint i20 = tgpig.x*NSG + ty; // row within S_v const uint i01 = i21 % args.ne01; const uint i11 = i21 % args.ne11; const float scale = 1.0f / sqrt((float)S_v); + const uint K = FC_gated_delta_net_K; + + // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 
0. // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous - device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + const uint state_in_base = (i23*K*args.ne21 + i21)*S_v*S_v + i20*S_v; + device const float * s_ptr = (device const float *) (s) + state_in_base; float ls[NSG]; @@ -2580,6 +2585,17 @@ kernel void kernel_gated_delta_net_impl( device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). + const int shift = (int)args.ne22 - (int)K; + + // output state base offset: after attention scores + const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23; + // output state per-slot size: S_v * S_v * H * n_seqs + const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23; + // per-(seq,head) offset within a slot + const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + for (short t = 0; t < args.ne22; t++) { float s_k = 0.0f; @@ -2627,13 +2643,25 @@ kernel void kernel_gated_delta_net_impl( b_ptr += args.ne21; g_ptr += args.ne21*G; - } - device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + if (K > 1u) { + const int target_slot = (int)t - shift; + if (target_slot >= 0 && target_slot < (int)K) { + device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base; + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is] = ls[j]; + } + } + } + } - FOR_UNROLL (short j = 0; j < NSG; j++) { - const short is = tx*NSG + j; - dst_state[is] = ls[j]; + if (K == 1u) { + device float * dst_state = (device float *) (dst) + attn_size + 
state_out_base; + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is] = ls[j]; + } } #undef S_v From d0759f0cd739066c1140e1b1c4b84b8343ef2de8 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 14 May 2026 10:59:33 +0200 Subject: [PATCH 14/28] delta_net_base: use ggml_pad instead of new_tensor --- src/models/delta-net-base.cpp | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp index 852d5bc88ef..e261ac54e20 100644 --- a/src/models/delta-net-base.cpp +++ b/src/models/delta-net-base.cpp @@ -547,15 +547,11 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( } const int64_t D = S_v * S_v * H_v; - // Memory has 1 + n_rs_seq slots (slot 0 = current, slots 1..n_rs_seq = rollback distances). - // The snapshot buffer must match — otherwise the deepest rollback (= n_rs_seq) reads - // uninitialized memory and corrupts the recurrent state. const int64_t K = (int64_t) cparams.n_rs_seq + 1; - ggml_tensor * state_3d = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, D, K, n_seqs); - ggml_tensor * slot_0 = ggml_view_2d(ctx0, state_3d, D, n_seqs, state_3d->nb[2], 0); - ggml_tensor * state_in_2d = ggml_reshape_2d(ctx0, s, D, n_seqs); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, state_in_2d, slot_0)); + // TODO: remove pad + simplify + ggml_tensor * state_in_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs); + ggml_tensor * state_3d = ggml_pad(ctx0, state_in_3d, 0, K - 1, 0, 0); ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, state_3d); cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il); From 78a78ae0d520404c62ce7328b748d7be04bae3f1 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 14 May 2026 17:20:50 +0800 Subject: [PATCH 15/28] review: add need_rs_seq --- common/common.cpp | 11 +---------- common/common.h | 9 +++++++++ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 
12acd5bf910..60cd6a48ab8 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1511,16 +1511,7 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.n_ctx = params.n_ctx; cparams.n_seq_max = params.n_parallel; - { - // Since MTP has a low number of draft tokens, enable recurrent checkpointing - // for hybrid attn models - // TODO: figure out how to make it place nicely with other speculative techniques - bool has_mtp = std::any_of(params.speculative.types.begin(), - params.speculative.types.end(), [&](auto t) { - return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP; - }); - cparams.n_rs_seq = has_mtp ? (uint32_t) params.speculative.draft.n_max : 0u; - } + cparams.n_rs_seq = params.speculative.need_n_rs_seq(); cparams.n_batch = params.n_batch; cparams.n_ubatch = params.n_ubatch; cparams.n_threads = params.cpuparams.n_threads; diff --git a/common/common.h b/common/common.h index 30b8933264d..190e521c600 100644 --- a/common/common.h +++ b/common/common.h @@ -13,6 +13,7 @@ #include #include #include +#include #if defined(_WIN32) && !defined(_WIN32_WINNT) #define _WIN32_WINNT 0x0A00 @@ -356,6 +357,14 @@ struct common_params_speculative { bool has_dft() const { return !draft.mparams.path.empty() || !draft.mparams.hf_repo.empty(); } + + uint32_t need_n_rs_seq() const { + bool needs_rs_seq = std::any_of(types.begin(), types.end(), [&](auto t) { + return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP; + }); + + return needs_rs_seq ? 
draft.n_max : 0u; + } }; struct common_params_vocoder { From 611f422db2308bf456e1066dc1e93b19d90cf42b Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 14 May 2026 17:22:52 +0800 Subject: [PATCH 16/28] review: rename part_bounded to n_rs --- common/common.cpp | 2 +- common/common.h | 4 ++-- examples/speculative-simple/speculative-simple.cpp | 1 - tools/server/server-context.cpp | 4 ++-- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 60cd6a48ab8..ee94274ca62 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1437,7 +1437,7 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) { if (llama_n_rs_seq(ctx) > 0) { LOG_INF("%s: the context supports bounded partial sequence removal\n", __func__); - res = COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED; + res = COMMON_CONTEXT_SEQ_RM_TYPE_RS; goto done; } diff --git a/common/common.h b/common/common.h index 190e521c600..97977d580ad 100644 --- a/common/common.h +++ b/common/common.h @@ -160,7 +160,7 @@ enum common_speculative_type { COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, // standalone draft model speculative decoding COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, // Eagle3 speculative decoding - COMMON_SPECULATIVE_TYPE_DRAFT_MTP, // multi-token prediction head loaded from the target GGUF + COMMON_SPECULATIVE_TYPE_DRAFT_MTP, // Multi-token prediction COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding based on n-grams COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values @@ -897,7 +897,7 @@ enum common_context_seq_rm_type { COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. 
no memory module) COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only - COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED = 3, // can seq_rm partial sequences, bounded by n_rs_seq + COMMON_CONTEXT_SEQ_RM_TYPE_RS = 3, // can seq_rm partial sequences, bounded by n_rs_seq }; // check if the llama_context can remove sequences diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 6848de988d0..868fd145fdc 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -83,7 +83,6 @@ int main(int argc, char ** argv) { // check if the context supports partial sequence removal const bool use_ckpt_tgt = (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) - || (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED); const bool use_ckpt_dft = (common_context_can_seq_rm(ctx_dft.get()) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); if (use_ckpt_tgt) { diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 0f89aa38ea3..f94da7d6442 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2738,8 +2738,8 @@ struct server_context_impl { // - the model uses SWA (and we are not using `swa_full`) // - the model supports partial sequence removal but only up to a fixed bound do_checkpoint = do_checkpoint && ( - ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL || - ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED || + ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL || + ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS || n_swa > 0); bool has_mtmd = false; From df4cd32f9042a662955251789d4e92c12d42ae0a Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 14 May 2026 17:24:34 +0800 Subject: [PATCH 17/28] review: deslop comments --- ggml/include/ggml.h | 10 ++++------ 
1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index be1b6be4cb2..41566d41aef 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2542,12 +2542,10 @@ extern "C" { // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST] // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306 // - // state is a 3D tensor of shape (S_v*S_v*H, K, n_seqs). K is the snapshot slot count: - // K == 1 → output carries the final state only. - // K > 1 → output carries K snapshot slots; the kernel writes the last min(n_tokens, K) - // per-token snapshots into the trailing slots (earlier slots are left untouched - // when n_tokens < K). - // Only slot 0 (state[:, 0, :]) is read as the initial state; the rest is shape signal. + // state is a 3D tensor of shape (S_v*S_v*H, K, n_seqs): + // K == 1: output carries the final state only. + // K > 1: output carries K snapshot slots; the kernel writes the last min(n_tokens, K) + // per-token snapshots into the trailing slots GGML_API struct ggml_tensor * ggml_gated_delta_net( struct ggml_context * ctx, struct ggml_tensor * q, From 9674711d163f88da37fc07f30ae2d3995d26935e Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 14 May 2026 17:40:55 +0800 Subject: [PATCH 18/28] review: rename, add asserts --- .../speculative-simple/speculative-simple.cpp | 2 +- ggml/src/ggml-cuda/gated_delta_net.cu | 22 +++++++++---------- src/llama-arch.cpp | 2 +- src/llama-arch.h | 2 +- src/llama-context.cpp | 2 +- src/llama-memory-recurrent.cpp | 10 ++++----- src/models/delta-net-base.cpp | 6 ++--- src/models/models.h | 2 +- 8 files changed, 23 insertions(+), 25 deletions(-) diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 868fd145fdc..5325bcc9e3f 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ 
b/examples/speculative-simple/speculative-simple.cpp @@ -82,7 +82,7 @@ int main(int argc, char ** argv) { } // check if the context supports partial sequence removal - const bool use_ckpt_tgt = (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) + const bool use_ckpt_tgt = (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); const bool use_ckpt_dft = (common_context_can_seq_rm(ctx_dft.get()) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); if (use_ckpt_tgt) { diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index ef4bb29c8ad..b4c9845e7a7 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -1,6 +1,6 @@ #include "gated_delta_net.cuh" -template +template __global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2) gated_delta_net_cuda(const float * q, const float * k, @@ -145,7 +145,7 @@ gated_delta_net_cuda(const float * q, attn_data += S_v * H; - if constexpr (keep_intermediates_t) { + if constexpr (keep_rs_t) { const int target_slot = t - shift; if (target_slot >= 0 && target_slot < K) { float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; @@ -158,7 +158,7 @@ gated_delta_net_cuda(const float * q, } } - if constexpr (!keep_intermediates_t) { + if constexpr (!keep_rs_t) { #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; @@ -167,7 +167,7 @@ gated_delta_net_cuda(const float * q, } } -template +template static void launch_gated_delta_net( const float * q_d, const float * k_d, const float * v_d, const float * g_d, const float * b_d, const float * s_d, @@ -191,26 +191,26 @@ static void launch_gated_delta_net( switch (S_v) { case 16: - gated_delta_net_cuda<16, KDA, keep_intermediates_t><<>>( + gated_delta_net_cuda<16, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, 
sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 32: - gated_delta_net_cuda<32, KDA, keep_intermediates_t><<>>( + gated_delta_net_cuda<32, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 64: { - gated_delta_net_cuda<64, KDA, keep_intermediates_t><<>>( + gated_delta_net_cuda<64, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; } case 128: { - gated_delta_net_cuda<128, KDA, keep_intermediates_t><<>>( + gated_delta_net_cuda<128, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); @@ -285,10 +285,10 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. 
const int K = (int) src_state->ne[1]; - const bool keep_intermediates = K > 1; + const bool keep_rs = K > 1; if (kda) { - if (keep_intermediates) { + if (keep_rs) { launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1, rq3, scale, K, stream); @@ -298,7 +298,7 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * sb1, sb2, sb3, neqk1, rq3, scale, K, stream); } } else { - if (keep_intermediates) { + if (keep_rs) { launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1, rq3, scale, K, stream); diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index c95f341b07d..4bee6fbe651 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -878,7 +878,7 @@ bool llm_arch_is_diffusion(const llm_arch & arch) { } } -bool llm_arch_supports_recurrent_partial_rollback(const llm_arch & arch) { +bool llm_arch_supports_rs_rollback(const llm_arch & arch) { switch (arch) { case LLM_ARCH_QWEN35: case LLM_ARCH_QWEN35MOE: diff --git a/src/llama-arch.h b/src/llama-arch.h index e3765dba3da..89cf16cc37c 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -637,4 +637,4 @@ bool llm_arch_is_recurrent (const llm_arch & arch); bool llm_arch_is_hybrid (const llm_arch & arch); bool llm_arch_is_diffusion (const llm_arch & arch); bool llm_arch_supports_sm_tensor(const llm_arch & arch); -bool llm_arch_supports_recurrent_partial_rollback(const llm_arch & arch); +bool llm_arch_supports_rs_rollback(const llm_arch & arch); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a481757e388..d62abc4009b 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -52,7 +52,7 @@ llama_context::llama_context( } cparams.n_rs_seq = params.n_rs_seq; - if (cparams.n_rs_seq > 0 && !llm_arch_supports_recurrent_partial_rollback(model.arch)) { + if (cparams.n_rs_seq > 0 && 
!llm_arch_supports_rs_rollback(model.arch)) { LLAMA_LOG_DEBUG("%s: n_rs_seq=%u requested but model arch does not support recurrent partial rollback; clamping to 0\n", __func__, cparams.n_rs_seq); cparams.n_rs_seq = 0; diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 084c5d9ea4f..109a77be404 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -170,12 +170,10 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos // partial rollback via per-token snapshot index (bounded by n_rs_seq) if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { const llama_pos rollback = cell.pos - (p0 - 1); - if (rollback >= 1 && rollback <= (llama_pos) n_rs_seq) { - set_rs_idx(seq_id, (uint32_t) rollback); - cell.pos = p0 - 1; - return true; - } - return false; + GGML_ASSERT(rollback >= 1 && rollback <= (llama_pos) n_rs_seq); + set_rs_idx(seq_id, (uint32_t) rollback); + cell.pos = p0 - 1; + return true; } // invalidate tails which will be cleared if (p0 <= cell.pos && cell.pos < p1) { diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp index e261ac54e20..2a4e00384e9 100644 --- a/src/models/delta-net-base.cpp +++ b/src/models/delta-net-base.cpp @@ -447,7 +447,7 @@ std::pair llm_build_delta_net_base::build_delta_ne return build_delta_net_chunking(q, k, v, g, b, s, il); } -bool llm_build_delta_net_base::keep_intermediates() const { +bool llm_build_delta_net_base::keep_rs() const { const int64_t n_seq_tokens = ubatch.n_seq_tokens; return cparams.n_rs_seq > 0 && n_seq_tokens > 1 @@ -466,7 +466,7 @@ ggml_tensor * llm_build_delta_net_base::build_conv_state( const uint32_t mem_size = mctx_cur->get_size(); const int64_t n_seqs = ubatch.n_seqs; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const bool keep = keep_intermediates(); + const bool keep = keep_rs(); ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); cb(conv_states, "conv_states", il); @@ 
-531,7 +531,7 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( const int64_t n_seqs = s->ne[3]; const int64_t n_seq_tokens = q->ne[2]; - if (!keep_intermediates()) { + if (!keep_rs()) { auto attn_out = build_delta_net(q, k, v, g, b, s, il); ggml_tensor * output = attn_out.first; ggml_tensor * new_state = attn_out.second; diff --git a/src/models/models.h b/src/models/models.h index 8bef479754a..4e40536a5ea 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -67,7 +67,7 @@ struct llm_build_delta_net_base : public llm_graph_context { int il); // true when speculative rollback is enabled and the batch fits in the rs cache - bool keep_intermediates() const; + bool keep_rs() const; // read conv state from cache, concat with qkv_mixed, write back (single slot or per-token) // qkv_mixed: (qkv_dim, n_seq_tokens, n_seqs); returns conv_input: (kernel_size + n_seq_tokens - 1, channels, n_seqs) From 7b54ac5d802668301b2703b424bfc31b91636ec9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 14 May 2026 18:19:26 +0300 Subject: [PATCH 19/28] server : adjust checkpoint logic (#11) * server : adjust checkpoint logic * cont : rm asserts --- common/arg.cpp | 7 +-- common/common.cpp | 25 +++++++++ common/common.h | 7 +++ src/llama-memory-recurrent.cpp | 10 ++-- tools/server/server-context.cpp | 91 ++++++++++++++++++--------------- 5 files changed, 92 insertions(+), 48 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 8b8eb7c12bd..84b3c8f962d 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3629,8 +3629,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("comma-separated list of types of speculative decoding to use (default: %s)\n", common_speculative_type_name_str(params.speculative.types).c_str()), [](common_params & params, const std::string & value) { - const auto enabled_types = string_split(value, ','); - params.speculative.types = common_speculative_types_from_names(enabled_types); 
+ const auto types_str = string_split(value, ','); + auto types = common_speculative_types_from_names(types_str); + params.speculative.types.insert(params.speculative.types.end(), types.begin(), types.end()); } ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_TYPE")); add_opt(common_arg( @@ -4119,7 +4120,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--spec-default"}, string_format("enable default speculative decoding config"), [](common_params & params) { - params.speculative.types = { COMMON_SPECULATIVE_TYPE_NGRAM_MOD }; + params.speculative.types.push_back(COMMON_SPECULATIVE_TYPE_NGRAM_MOD); params.speculative.ngram_mod.n_match = 24; params.speculative.ngram_mod.n_min = 48; params.speculative.ngram_mod.n_max = 64; diff --git a/common/common.cpp b/common/common.cpp index ee94274ca62..9f597b4feb8 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1455,6 +1455,23 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) { return res; } +void common_context_seq_rm(llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + auto * mem = llama_get_memory(ctx); + if (!llama_memory_seq_rm(mem, seq_id, p0, p1)) { + GGML_ABORT("%s", string_format("failed to remove sequence %d with p0=%d, p1=%d\n", seq_id, p0, p1).c_str()); + } +} + +void common_context_seq_cp(llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + auto * mem = llama_get_memory(ctx); + llama_memory_seq_cp(mem, seq_id_src, seq_id_dst, p0, p1); +} + +void common_context_seq_add(llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + auto * mem = llama_get_memory(ctx); + llama_memory_seq_add(mem, seq_id, p0, p1, delta); +} + void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora) { std::vector loras; std::vector scales; @@ -2081,3 +2098,11 @@ void 
common_prompt_checkpoint::load_dft( GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_dft.size(), n); } } + +void common_prompt_checkpoint::clear_tgt() { + data_tgt.clear(); +} + +void common_prompt_checkpoint::clear_dft() { + data_dft.clear(); +} diff --git a/common/common.h b/common/common.h index 97977d580ad..ee6eb00e77a 100644 --- a/common/common.h +++ b/common/common.h @@ -904,6 +904,10 @@ enum common_context_seq_rm_type { // note: clears the memory of the context common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx); +// aborts execution on failure +void common_context_seq_rm (llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1); +void common_context_seq_add(llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta); +void common_context_seq_cp (llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1); // // Batch utils @@ -1085,4 +1089,7 @@ struct common_prompt_checkpoint { llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) const; + + void clear_tgt(); + void clear_dft(); }; diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 109a77be404..084c5d9ea4f 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -170,10 +170,12 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos // partial rollback via per-token snapshot index (bounded by n_rs_seq) if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { const llama_pos rollback = cell.pos - (p0 - 1); - GGML_ASSERT(rollback >= 1 && rollback <= (llama_pos) n_rs_seq); - set_rs_idx(seq_id, (uint32_t) rollback); - cell.pos = p0 - 1; - return true; + if (rollback >= 1 && rollback <= (llama_pos) n_rs_seq) { + set_rs_idx(seq_id, (uint32_t) rollback); + cell.pos = p0 - 1; + return true; + } + return false; } // invalidate tails which will be cleared if (p0 <= cell.pos && cell.pos < p1) { diff --git 
a/tools/server/server-context.cpp b/tools/server/server-context.cpp index f94da7d6442..fc92821940b 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -145,9 +145,9 @@ struct server_slot { SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size()); - llama_memory_seq_rm(llama_get_memory(ctx_tgt), id, -1, -1); + common_context_seq_rm(ctx_tgt, id, -1, -1); if (ctx_dft) { - llama_memory_seq_rm(llama_get_memory(ctx_dft), id, -1, -1); + common_context_seq_rm(ctx_dft, id, -1, -1); } prompt.tokens.clear(); @@ -517,12 +517,12 @@ struct server_slot { void copy_state_to(server_slot & other) const { GGML_ASSERT(state == SLOT_STATE_DONE_PROMPT); - llama_memory_seq_rm(llama_get_memory(ctx_tgt), other.id, -1, -1); - llama_memory_seq_cp(llama_get_memory(ctx_tgt), id, other.id, -1, -1); + common_context_seq_rm(ctx_tgt, other.id, -1, -1); + common_context_seq_cp(ctx_tgt, id, other.id, -1, -1); if (ctx_dft) { - llama_memory_seq_rm(llama_get_memory(ctx_dft), other.id, -1, -1); - llama_memory_seq_cp(llama_get_memory(ctx_dft), id, other.id, -1, -1); + common_context_seq_rm(ctx_dft, other.id, -1, -1); + common_context_seq_cp(ctx_dft, id, other.id, -1, -1); } other.n_decoded = n_decoded; @@ -789,6 +789,8 @@ struct server_context_impl { cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP; } + // note: for small models maybe we can set this to the maximum possible draft from all speculative types + // the extra memory for small models is likely negligible? 
cparams.n_rs_seq = 0; ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); @@ -803,6 +805,7 @@ struct server_context_impl { auto cparams_mtp = common_context_params_to_llama(params_base); cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP; + cparams_mtp.n_rs_seq = 0; ctx_dft.reset(llama_init_from_model(model_tgt, cparams_mtp)); if (ctx_dft == nullptr) { @@ -2227,12 +2230,12 @@ struct server_context_impl { SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); - llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, n_keep , n_keep + n_discard); - llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); + common_context_seq_rm (ctx_tgt, slot.id, n_keep , n_keep + n_discard); + common_context_seq_add(ctx_tgt, slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); if (ctx_dft) { - llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, n_keep , n_keep + n_discard); - llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard); + common_context_seq_rm (ctx_dft.get(), slot.id, n_keep , n_keep + n_discard); + common_context_seq_add(ctx_dft.get(), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard); } // add generated tokens to cache @@ -2339,14 +2342,23 @@ struct server_context_impl { slot.n_draft_total += draft.size(); // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] + const bool use_ckpt_dft = ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + if (ctx_dft) { - ckpt.load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + if (use_ckpt_dft) { + ckpt.load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + } - llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, 
ckpt.pos_max + 1, -1); + common_context_seq_rm(ctx_dft.get(), slot.id, ckpt.pos_max + 1, -1); } if (!draft.empty()) { - const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + const bool use_ckpt_tgt = + ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL || + (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && draft.size() > llama_n_rs_seq(ctx_tgt)); + + const bool use_ckpt_dft = + (ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && draft.size() > llama_n_rs_seq(ctx_dft.get())); if (use_ckpt_tgt) { //const int64_t t_start = ggml_time_us(); @@ -2361,6 +2373,10 @@ struct server_context_impl { (float) ckpt.size() / 1024 / 1024, (float) ckpt.data_dft.size() / 1024 / 1024); } + + if (use_ckpt_dft) { + ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + } } } @@ -2532,12 +2548,12 @@ struct server_context_impl { const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; - llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, head_p, head_c); - llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, head_c, head_c + n_match, kv_shift); + common_context_seq_rm (ctx_tgt, slot.id, head_p, head_c); + common_context_seq_add(ctx_tgt, slot.id, head_c, head_c + n_match, kv_shift); if (ctx_dft) { - llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, head_p, head_c); - llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, head_c, head_c + n_match, kv_shift); + common_context_seq_rm (ctx_dft.get(), slot.id, head_p, head_c); + common_context_seq_add(ctx_dft.get(), slot.id, head_c, head_c + n_match, kv_shift); } for (size_t i = 0; i < n_match; i++) { @@ -2700,18 +2716,10 @@ struct server_context_impl { SLT_TRC(slot, "cached n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0); - if (!llama_memory_seq_rm(llama_get_memory(ctx_tgt), slot.id, p0, -1)) { - SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", 
p0); - - slot.prompt_clear(true); - - // there is no common part left - slot.n_prompt_tokens_cache = 0; - } else { - if (ctx_dft && !llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, p0, -1)) { - GGML_ABORT("failed to truncate draft context\n"); - } - } + common_context_seq_rm(ctx_tgt, slot.id, p0, -1); + if (ctx_dft) { + common_context_seq_rm(ctx_dft.get(), slot.id, p0, -1); + } // If using an alora, there may be uncached tokens that come // before the invocation sequence. When this happens, the @@ -3177,13 +3185,8 @@ struct server_context_impl { // verify and try to accept the draft { - const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; - - // only save the sampler sampler state if we use checkpoints - common_sampler_ptr smpl_save; - if (use_ckpt_tgt) { - smpl_save.reset(common_sampler_clone(slot.smpl.get())); - } + // save the sampler sampler state in case we need to restore it + common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get())); GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft); @@ -3191,8 +3194,14 @@ struct server_context_impl { GGML_ASSERT(accepted.size() >= 1); + const uint32_t n_rollback = slot.spec_draft.size() + 1 - accepted.size(); + + const bool use_ckpt_tgt = + ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL || + (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && n_rollback > llama_n_rs_seq(ctx_tgt)); + // check for partial draft acceptance - if (accepted.size() < slot.spec_draft.size() + 1) { + if (n_rollback > 0) { if (use_ckpt_tgt) { if (trace > 0) { SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size()); @@ -3208,13 +3217,13 @@ struct server_context_impl { { ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - 
llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, ckpt.pos_max + 1, -1); + common_context_seq_rm(slot.ctx_tgt, slot.id, ckpt.pos_max + 1, -1); } if (slot.ctx_dft) { ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, ckpt.pos_max + 1, -1); + common_context_seq_rm(slot.ctx_dft, slot.id, ckpt.pos_max + 1, -1); } slot.prompt.tokens.keep_first(ckpt.n_tokens); @@ -3250,9 +3259,9 @@ struct server_context_impl { slot.sampled = ids.back(); // last accepted token SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft); - llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, slot.prompt.tokens.pos_next(), -1); + common_context_seq_rm(slot.ctx_tgt, slot.id, slot.prompt.tokens.pos_next(), -1); if (slot.ctx_dft) { - llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, slot.prompt.tokens.pos_next(), -1); + common_context_seq_rm(slot.ctx_dft, slot.id, slot.prompt.tokens.pos_next(), -1); } for (size_t i = 0; i < ids.size(); ++i) { From 749a0b2bdfc0c4dc933c0f3c7edf39ab12812349 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 14 May 2026 23:40:15 +0800 Subject: [PATCH 20/28] server-context: fix early exit --- tools/server/server-context.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index fc92821940b..4d162f81d9b 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -3246,7 +3246,6 @@ struct server_context_impl { const auto ids = std::move(slot.spec_draft); - slot.n_decoded += ids.size(); slot.t_token_generation = std::max(1, t_current - slot.t_start_generation) / 1e3; // update how many tokens out of those tested were accepted @@ -3273,6 +3272,8 @@ struct server_context_impl { // TODO: set result.probs + slot.n_decoded += 1; + if (!process_token(result, 
slot)) { slot.print_timings(); send_final_response(slot); From d42d25de11be671e234025edbe66179dcf69fac7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 15 May 2026 13:51:12 +0300 Subject: [PATCH 21/28] spec : fix compatibility with n-gram and add TODOs (#13) * metal : cleanup * llama : fix faulty bitwise check in recurrent memory * server : disable RS-based MTP in combination with other spec types * spec : add TODOs * cont : fix comment * cont : update comment * common : fix logic for ngram + mtp compat --- common/common.cpp | 24 ++++++++++++++++++++++++ common/common.h | 2 +- common/speculative.cpp | 7 ++++++- ggml/src/ggml-metal/ggml-metal.metal | 6 +++--- src/llama-memory-recurrent.cpp | 11 ++++++++++- 5 files changed, 44 insertions(+), 6 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 9f597b4feb8..8b6d182f549 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -7,6 +7,7 @@ #include "log.h" #include "llama.h" #include "sampling.h" +#include "speculative.h" #include "unicode.h" #include @@ -1247,6 +1248,29 @@ common_init_result::common_init_result(common_params & params) : cparams.n_samplers = pimpl->samplers_seq_config.size(); } + // [TAG_RS_STATE_ROLLBACK_SUPPORT] + // TODO: ngram speculative methods require checkpointing in addition to partial RS rollback + // currently this is not supported. 
so we disable the partial rollback + if (cparams.n_rs_seq > 0 && (llama_model_is_recurrent(model) || llama_model_is_hybrid(model))) { + auto & types = params.speculative.types; + + for (int i = 0; i < (int) types.size(); i++) { + if (types[i] == COMMON_SPECULATIVE_TYPE_NONE) { + continue; + } + if (types[i] == COMMON_SPECULATIVE_TYPE_DRAFT_MTP) { + continue; + } + + cparams.n_rs_seq = 0; + + LOG_WRN("%s: recurrent state rollback is not compatible with '%s' - disabling rollback support\n", __func__, + common_speculative_type_to_str(types[i]).c_str()); + + break; + } + } + llama_context * lctx = llama_init_from_model(model, cparams); if (lctx == NULL) { LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str()); diff --git a/common/common.h b/common/common.h index ee6eb00e77a..4cca9d71568 100644 --- a/common/common.h +++ b/common/common.h @@ -303,7 +303,7 @@ struct common_params_speculative_draft { int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding float p_split = 0.1f; // speculative decoding split probability - float p_min = 0.75f; // minimum speculative decoding probability (greedy) + float p_min = 0.75f; // minimum speculative decoding probability (greedy) // TODO: change default to 0.0f common_params_model mparams; diff --git a/common/speculative.cpp b/common/speculative.cpp index 0a78bc1136c..3488b9393c5 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -424,7 +424,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { for (auto & s : smpls) { common_params_sampling sparams; sparams.no_perf = false; - sparams.top_k = 1; + sparams.top_k = 1; // TODO: re-enable top_k == 10 and utilize `p_min` spec param sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K }; s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams)); } @@ -1494,6 +1494,11 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u GGML_ASSERT(impl); + // TODO: 
currently only the implementation that generated the draft is used to accept it + // however, some implementations (such as MTP) need to also "see" the accepted tokens + // extend `common_speculative_impl::accept()` with an extra argument `bool is_other` to + // inform the implementation if the accepted tokens are from another implementation and + // pass the accepted tokens to all remaining implementations using `is_other == true` { common_time_meas tm(impl->t_accept_us, !impl->gen_perf); if (n_accepted > 0) { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index deb616105ae..82e29d5ad7c 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2531,7 +2531,7 @@ kernel void kernel_rwkv_wkv7_f32( constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]]; constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]]; -constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]]; +constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]]; #if 1 template @@ -2549,6 +2549,7 @@ kernel void kernel_gated_delta_net_impl( uint3 ntg[[threads_per_threadgroup]]) { #define S_v FC_gated_delta_net_ne20 #define G FC_gated_delta_net_ne30 +#define K FC_gated_delta_net_K const uint tx = tpitg.x; const uint ty = tpitg.y; @@ -2562,8 +2563,6 @@ kernel void kernel_gated_delta_net_impl( const float scale = 1.0f / sqrt((float)S_v); - const uint K = FC_gated_delta_net_K; - // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. 
// state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous const uint state_in_base = (i23*K*args.ne21 + i21)*S_v*S_v + i20*S_v; @@ -2666,6 +2665,7 @@ kernel void kernel_gated_delta_net_impl( #undef S_v #undef G +#undef K } typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t; diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 084c5d9ea4f..64c7e726fdc 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -719,6 +719,15 @@ size_t llama_memory_recurrent::size_s_bytes() const { void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { GGML_UNUSED(flags); + // [TAG_RS_STATE_ROLLBACK_SUPPORT] + if (n_rs_seq != 0) { + for (uint32_t i = 0; i < rs_idx.size(); ++i) { + if (rs_idx[i] != 0) { + GGML_ABORT("recurrent state read/write is not supported with partial rollback"); + } + } + } + std::vector> cell_ranges; // ranges, from inclusive, to exclusive uint32_t cell_count = 0; @@ -743,7 +752,7 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq cell_ranges.emplace_back(cell_range_begin, size); } - if (flags % LLAMA_STATE_SEQ_FLAGS_ON_DEVICE && cell_ranges.size() > 1) { + if ((flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) && cell_ranges.size() > 1) { GGML_ABORT("cannot save/load multiple ranges of cells to/from device memory\n"); } From cddbb7fd914838c645a43d6a2507809c01fdbfa6 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 16 May 2026 01:56:52 +0800 Subject: [PATCH 22/28] llama-memory: enable checkpointing with partial rollback --- src/llama-memory-recurrent.cpp | 73 ++++++++++--- tests/CMakeLists.txt | 3 + tests/test-recurrent-state-rollback.cpp | 138 ++++++++++++++++++++++++ 3 files changed, 200 insertions(+), 14 deletions(-) create mode 100644 tests/test-recurrent-state-rollback.cpp diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 
64c7e726fdc..49a80b61641 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -156,6 +156,15 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1 = std::numeric_limits::max(); } + const bool rm_all = p0 == 0 && p1 == std::numeric_limits::max(); + if (rm_all) { + if (seq_id >= 0) { + set_rs_idx(seq_id, 0); + } else { + std::fill(rs_idx.begin(), rs_idx.end(), 0); + } + } + // models like Mamba or RWKV can't have a state partially erased at the end // of the sequence because their state isn't preserved for previous tokens if (seq_id >= (int64_t) size) { @@ -719,16 +728,8 @@ size_t llama_memory_recurrent::size_s_bytes() const { void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { GGML_UNUSED(flags); - // [TAG_RS_STATE_ROLLBACK_SUPPORT] - if (n_rs_seq != 0) { - for (uint32_t i = 0; i < rs_idx.size(); ++i) { - if (rs_idx[i] != 0) { - GGML_ABORT("recurrent state read/write is not supported with partial rollback"); - } - } - } - std::vector> cell_ranges; // ranges, from inclusive, to exclusive + std::vector> cell_ranges_data; // logical source row ranges uint32_t cell_count = 0; // Count the number of cells with the specified seq_id @@ -738,6 +739,35 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq const auto & cell = cells[i]; if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { ++cell_count; + uint32_t rs_idx_cur = 0; + + if (n_rs_seq != 0) { + if (seq_id != -1) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < rs_idx.size()); + rs_idx_cur = rs_idx[seq_id]; + } else { + bool has_rs_idx = false; + for (const llama_seq_id cell_seq_id : cell.seq_id) { + GGML_ASSERT(cell_seq_id >= 0 && (size_t) cell_seq_id < rs_idx.size()); + + const uint32_t seq_rs_idx = rs_idx[cell_seq_id]; + if (!has_rs_idx) { + rs_idx_cur = seq_rs_idx; + has_rs_idx = true; + } else if (rs_idx_cur != seq_rs_idx) { + 
GGML_ABORT("cannot write shared recurrent state with different rollback indices"); + } + } + } + } + + const uint32_t cell_id = rs_idx_cur * size + (cell.src >= 0 ? cell.src : (int32_t) i); + if (cell_ranges_data.empty() || cell_ranges_data.back().second != cell_id) { + cell_ranges_data.emplace_back(cell_id, cell_id + 1); + } else { + cell_ranges_data.back().second++; + } + if (cell_range_begin == size) { cell_range_begin = i; } @@ -763,10 +793,16 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq } GGML_ASSERT(cell_count == cell_count_check); + cell_count_check = 0; + for (const auto & range : cell_ranges_data) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + io.write(&cell_count, sizeof(cell_count)); state_write_meta(io, cell_ranges, seq_id); - state_write_data(io, cell_ranges); + state_write_data(io, cell_ranges_data); } void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { @@ -788,6 +824,14 @@ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_i } throw std::runtime_error("failed to restore kv cache"); } + + if (n_rs_seq != 0) { + if (seq_id == -1) { + std::fill(rs_idx.begin(), rs_idx.end(), 0); + } else { + set_rs_idx(seq_id, 0); + } + } } void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { @@ -830,7 +874,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); io.write(&r_size_row, sizeof(r_size_row)); - // Write each range of cells of r_size_row length + // Write each logical cell row range. With pending recurrent rollback, + // the logical current state may live in a rollback snapshot plane. 
for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * r_size_row; @@ -851,7 +896,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); io.write(&s_size_row, sizeof(s_size_row)); - // Write each range of S tensor rows + // Write each logical cell row range. With pending recurrent rollback, + // the logical current state may live in a rollback snapshot plane. for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * s_size_row; @@ -878,9 +924,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // Write GQA embedding size io.write(&n_embd_s, sizeof(n_embd_s)); - // For each row, we get the element values of each cell + // For each row, we get the element values of each logical cell for (uint32_t j = 0; j < n_embd_s; ++j) { - // Write each range of cells of s_size_el length for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * mem_size) * s_size_el; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 3ee535224d9..0fdbd39c94a 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -252,6 +252,9 @@ llama_build_and_test(test-backend-sampler.cpp LABEL "model") llama_build_and_test(test-state-restore-fragmented.cpp LABEL "model" ARGS -m "${MODEL_DEST}") set_tests_properties(test-state-restore-fragmented PROPERTIES FIXTURES_REQUIRED test-download-model) +llama_build_and_test(test-recurrent-state-rollback.cpp LABEL "model" ARGS -m "${MODEL_DEST}") +set_tests_properties(test-recurrent-state-rollback PROPERTIES FIXTURES_REQUIRED test-download-model) + if (NOT GGML_BACKEND_DL) # these tests use the backends directly and cannot be built with dynamic loading 
llama_build_and_test(test-barrier.cpp) diff --git a/tests/test-recurrent-state-rollback.cpp b/tests/test-recurrent-state-rollback.cpp new file mode 100644 index 00000000000..8c465c0f831 --- /dev/null +++ b/tests/test-recurrent-state-rollback.cpp @@ -0,0 +1,138 @@ +#include "arg.h" +#include "common.h" +#include "llama.h" + +#include +#include +#include +#include +#include + +static llama_context * make_ctx(const common_params & params, llama_model * model) { + auto cparams = common_context_params_to_llama(params); + cparams.n_seq_max = 1; + cparams.n_rs_seq = 8; + cparams.n_batch = std::max(cparams.n_batch, (uint32_t) (cparams.n_rs_seq + 1)); + cparams.n_ubatch = std::max(cparams.n_ubatch, (uint32_t) (cparams.n_rs_seq + 1)); + return llama_init_from_model(model, cparams); +} + +static bool decode_tokens(llama_context * ctx, const std::vector & tokens, uint32_t count) { + llama_batch batch = llama_batch_init(count, 0, 1); + for (uint32_t pos = 0; pos < count; ++pos) { + common_batch_add(batch, tokens[pos], pos, { 0 }, false); + } + const bool ok = llama_decode(ctx, batch) == 0; + llama_batch_free(batch); + return ok; +} + +static bool decode_one(llama_context * ctx, llama_token tok, llama_pos pos) { + llama_batch batch = llama_batch_init(1, 0, 1); + common_batch_add(batch, tok, pos, { 0 }, true); + const bool ok = llama_decode(ctx, batch) == 0; + llama_batch_free(batch); + return ok; +} + +int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + + common_params params; + params.sampling.seed = 1234; + params.n_predict = 1; + + common_init(); + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { + return 1; + } + + ggml_backend_load_all(); + + common_init_result_ptr llama_init = common_init_from_params(params); + llama_model * model = llama_init->model(); + if (model == nullptr) { + fprintf(stderr, "%s : failed to init model\n", __func__); + return 1; + } + + if (!llama_model_is_recurrent(model) && !llama_model_is_hybrid(model)) { + 
fprintf(stderr, "%s : skipping for non-recurrent model\n", __func__); + return 0; + } + + const llama_vocab * vocab = llama_model_get_vocab(model); + const int n_vocab = llama_vocab_n_tokens(vocab); + + llama_context * ctx_src = make_ctx(params, model); + llama_context * ctx_dst = make_ctx(params, model); + if (ctx_src == nullptr || ctx_dst == nullptr) { + fprintf(stderr, "%s : failed to init contexts\n", __func__); + return 1; + } + + if (llama_n_rs_seq(ctx_src) == 0) { + fprintf(stderr, "%s : skipping because n_rs_seq is disabled\n", __func__); + llama_free(ctx_src); + llama_free(ctx_dst); + return 0; + } + + std::vector tokens = common_tokenize(ctx_src, "The quick brown fox jumps", true); + const uint32_t n_rs_seq = llama_n_rs_seq(ctx_src); + if (tokens.size() > n_rs_seq + 1) { + tokens.resize(n_rs_seq + 1); + } + if (tokens.size() < 2) { + fprintf(stderr, "%s : not enough prompt tokens\n", __func__); + return 1; + } + const uint32_t n_tokens = tokens.size(); + const llama_token last_tok = tokens.back(); + const llama_pos last_pos = (llama_pos) n_tokens - 2; + + // Decode the full prompt on the source, then roll back the last position. + // Rollback leaves the recurrent memory in a snapshot state (rs_idx != 0). + if (!decode_tokens(ctx_src, tokens, n_tokens)) { + fprintf(stderr, "%s : failed to decode prompt\n", __func__); + return 1; + } + if (!llama_memory_seq_rm(llama_get_memory(ctx_src), 0, last_pos, -1)) { + fprintf(stderr, "%s : rollback failed\n", __func__); + return 1; + } + + // Save the rolled-back state and restore it into a fresh context. + common_prompt_checkpoint ckpt; + ckpt.update_tgt(ctx_src, 0, 0); + ckpt.load_tgt(ctx_dst, 0, 0); + + // Replay the rolled-back token on both contexts and compare logits. 
+ if (!decode_one(ctx_src, last_tok, last_pos) || + !decode_one(ctx_dst, last_tok, last_pos)) { + fprintf(stderr, "%s : replay failed\n", __func__); + return 1; + } + + const float * logits_src = llama_get_logits_ith(ctx_src, 0); + const float * logits_dst = llama_get_logits_ith(ctx_dst, 0); + if (logits_src == nullptr || logits_dst == nullptr) { + fprintf(stderr, "%s : missing logits\n", __func__); + return 1; + } + + constexpr float eps = 1e-5f; + for (int i = 0; i < n_vocab; ++i) { + if (std::fabs(logits_src[i] - logits_dst[i]) > eps) { + fprintf(stderr, "%s : logits mismatch at token %d (%g != %g)\n", + __func__, i, (double) logits_src[i], (double) logits_dst[i]); + return 1; + } + } + + fprintf(stderr, "%s : recurrent rollback checkpoint restored successfully\n", __func__); + llama_free(ctx_src); + llama_free(ctx_dst); + return 0; +} From 6ef79f7c19c18b7145f045107b0775e8949d72a0 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 16 May 2026 11:45:33 +0800 Subject: [PATCH 23/28] cont: add test-case for loading into a dirty ctx --- tests/test-recurrent-state-rollback.cpp | 47 +++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/tests/test-recurrent-state-rollback.cpp b/tests/test-recurrent-state-rollback.cpp index 8c465c0f831..be19316db8a 100644 --- a/tests/test-recurrent-state-rollback.cpp +++ b/tests/test-recurrent-state-rollback.cpp @@ -131,8 +131,55 @@ int main(int argc, char ** argv) { } } + // Repeat the load into a context that already has its own rollback state: + // groups 1..n_rs_seq hold a *different* prompt's history, and rs_idx[0] is + // non-zero at load time. The restore must wipe that state and still match. 
+ llama_context * ctx_dirty = make_ctx(params, model); + if (ctx_dirty == nullptr) { + fprintf(stderr, "%s : failed to init dirty ctx\n", __func__); + return 1; + } + + std::vector noise = tokens; + for (auto & t : noise) { + t = (t + 1) % n_vocab; + if (t < 0) { + t = 0; + } + } + if (!decode_tokens(ctx_dirty, noise, n_tokens)) { + fprintf(stderr, "%s : dirty prompt decode failed\n", __func__); + return 1; + } + if (!llama_memory_seq_rm(llama_get_memory(ctx_dirty), 0, last_pos, -1)) { + fprintf(stderr, "%s : dirty rollback failed\n", __func__); + return 1; + } + + ckpt.load_tgt(ctx_dirty, 0, 0); + + if (!decode_one(ctx_dirty, last_tok, last_pos)) { + fprintf(stderr, "%s : dirty replay failed\n", __func__); + return 1; + } + + const float * logits_dirty = llama_get_logits_ith(ctx_dirty, 0); + if (logits_dirty == nullptr) { + fprintf(stderr, "%s : missing dirty logits\n", __func__); + return 1; + } + + for (int i = 0; i < n_vocab; ++i) { + if (std::fabs(logits_src[i] - logits_dirty[i]) > eps) { + fprintf(stderr, "%s : dirty-ctx logits mismatch at token %d (%g != %g)\n", + __func__, i, (double) logits_src[i], (double) logits_dirty[i]); + return 1; + } + } + fprintf(stderr, "%s : recurrent rollback checkpoint restored successfully\n", __func__); llama_free(ctx_src); llama_free(ctx_dst); + llama_free(ctx_dirty); return 0; } From 0f6f0d6b2a9be804e330588b20053bd87e345632 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 16 May 2026 17:40:53 +0800 Subject: [PATCH 24/28] llama-memory-recurrent: clear rs_idx in clear --- conversion/base.py | 5 +++++ conversion/qwen.py | 3 --- src/llama-memory-recurrent.cpp | 2 ++ tests/test-llama-archs.cpp | 3 ++- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/conversion/base.py b/conversion/base.py index 3c4be034154..30c2124c2b9 100644 --- a/conversion/base.py +++ b/conversion/base.py @@ -107,6 +107,11 @@ class ModelBase: disable_mistral_community_chat_template: bool = False sentence_transformers_dense_modules: bool = 
False + # MTP (multi-token prediction) export modes; set by main() before instantiation. + # Architectures opt in by overriding the handling (see _Qwen35MtpMixin). + mtp_only: bool = False + no_mtp: bool = False + def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False, use_temp_file: bool = False, eager: bool = False, metadata_override: Path | None = None, model_name: str | None = None, diff --git a/conversion/qwen.py b/conversion/qwen.py index 78c8293b86b..6949ab979c8 100644 --- a/conversion/qwen.py +++ b/conversion/qwen.py @@ -548,9 +548,6 @@ class _Qwen35MtpMixin: block_count: int tensor_map: gguf.TensorNameMap - mtp_only: bool = False - no_mtp: bool = False - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.block_count = self.hparams["num_hidden_layers"] diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 49a80b61641..aeb866657f2 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -143,6 +143,8 @@ void llama_memory_recurrent::clear(bool data) { ggml_backend_buffer_clear(buf.get(), 0); } } + + std::fill(rs_idx.begin(), rs_idx.end(), 0); } bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { diff --git a/tests/test-llama-archs.cpp b/tests/test-llama-archs.cpp index 03d7c19c78b..16af11a2862 100644 --- a/tests/test-llama-archs.cpp +++ b/tests/test-llama-archs.cpp @@ -406,7 +406,8 @@ static bool arch_supported(const llm_arch arch) { if (arch == LLM_ARCH_DEEPSEEK2OCR) { return false; } -// FIXME some models are segfaulting with WebGPU: + + // FIXME some models are segfaulting with WebGPU: #ifdef GGML_USE_WEBGPU if (arch == LLM_ARCH_QWEN3NEXT || arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE || arch == LLM_ARCH_KIMI_LINEAR) { return false; From 37a479f72b5d0fadf4eb4684be85371dab668d84 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 16 May 2026 18:12:02 +0800 Subject: [PATCH 
25/28] download: fix mtp path --- common/download.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/download.cpp b/common/download.cpp index f3dacb7e3e0..103bc408faf 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -975,7 +975,7 @@ std::vector common_list_cached_models() { auto split = get_gguf_split_info(f.path); if (split.index != 1 || split.tag.empty() || split.prefix.find("mmproj") != std::string::npos || - split.prefix.find("MTP") != std::string::npos) { + split.prefix.find("mtp-") != std::string::npos) { continue; } if (seen.insert(f.repo_id + ":" + split.tag).second) { From 8e9a07d987d0b2c63a986275d77bceb42d5bce5a Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 16 May 2026 18:15:09 +0800 Subject: [PATCH 26/28] llama-arch: fix enorm op --- src/llama-arch.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 4bee6fbe651..c9eead18aa3 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -762,7 +762,7 @@ static const std::map LLM_TENSOR_INFOS = { // the model loader doesn't fault on the block index. 
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, From 5a818cd3a376d516763ee5567716ff289cf09a67 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 16 May 2026 18:25:44 +0800 Subject: [PATCH 27/28] docs: update docs --- tools/cli/README.md | 7 +++---- tools/completion/README.md | 5 ++--- tools/server/README.md | 19 +++++++++++-------- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/tools/cli/README.md b/tools/cli/README.md index 9f0574d25d3..c40b5a21cc0 100644 --- a/tools/cli/README.md +++ b/tools/cli/README.md @@ -55,7 +55,6 @@ | `-ctv, --cache-type-v TYPE` | KV cache data type for V
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V) | | `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)
(env: LLAMA_ARG_DEFRAG_THOLD) | | `-np, --parallel N` | number of parallel sequences to decode (default: 1)
(env: LLAMA_ARG_N_PARALLEL) | -| `--rpc SERVERS` | comma-separated list of RPC servers (host:port)
(env: LLAMA_ARG_RPC) | | `--mlock` | force system to keep model in RAM rather than swapping or compressing
(env: LLAMA_ARG_MLOCK) | | `--mmap, --no-mmap` | whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)
(env: LLAMA_ARG_MMAP) | | `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. (default: disabled)
(env: LLAMA_ARG_DIO) | @@ -94,8 +93,8 @@ | `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) | | `--offline` | Offline mode: forces use of cache, prevents network access
(env: LLAMA_OFFLINE) | | `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:
- 0: generic output
- 1: error
- 2: warning
- 3: info
- 4: debug
(default: 3)

(env: LLAMA_LOG_VERBOSITY) | -| `--log-prefix` | Enable prefix in log messages
(env: LLAMA_LOG_PREFIX) | -| `--log-timestamps` | Enable timestamps in log messages
(env: LLAMA_LOG_TIMESTAMPS) | +| `--log-prefix, --no-log-prefix` | Enable prefix in log messages
(env: LLAMA_ARG_LOG_PREFIX) | +| `--log-timestamps, --no-log-timestamps` | Enable timestamps in log messages
(env: LLAMA_ARG_LOG_TIMESTAMPS) | | `--spec-draft-type-k, -ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_K) | | `--spec-draft-type-v, -ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_V) | @@ -199,7 +198,7 @@ | `--spec-draft-device, -devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)
use --list-devices to see a list of available devices | | `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)
(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) | | `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)
(env: LLAMA_ARG_SPEC_DRAFT_MODEL) | -| `--spec-type none,draft-simple,draft-eagle3,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-mod,ngram-cache` | comma-separated list of types of speculative decoding to use (default: none)

(env: LLAMA_ARG_SPEC_TYPE) | +| `--spec-type none,draft-simple,draft-eagle3,draft-mtp,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-mod,ngram-cache` | comma-separated list of types of speculative decoding to use (default: none)

(env: LLAMA_ARG_SPEC_TYPE) | | `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) | | `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) | | `--spec-ngram-mod-n-match N` | ngram-mod lookup length (default: 24) | diff --git a/tools/completion/README.md b/tools/completion/README.md index 048cf7416fc..e5dd7f6f4e7 100644 --- a/tools/completion/README.md +++ b/tools/completion/README.md @@ -138,7 +138,6 @@ llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1 | `-ctv, --cache-type-v TYPE` | KV cache data type for V
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V) | | `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)
(env: LLAMA_ARG_DEFRAG_THOLD) | | `-np, --parallel N` | number of parallel sequences to decode (default: 1)
(env: LLAMA_ARG_N_PARALLEL) | -| `--rpc SERVERS` | comma-separated list of RPC servers (host:port)
(env: LLAMA_ARG_RPC) | | `--mlock` | force system to keep model in RAM rather than swapping or compressing
(env: LLAMA_ARG_MLOCK) | | `--mmap, --no-mmap` | whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)
(env: LLAMA_ARG_MMAP) | | `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. (default: disabled)
(env: LLAMA_ARG_DIO) | @@ -177,8 +176,8 @@ llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1 | `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) | | `--offline` | Offline mode: forces use of cache, prevents network access
(env: LLAMA_OFFLINE) | | `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:
- 0: generic output
- 1: error
- 2: warning
- 3: info
- 4: debug
(default: 3)

(env: LLAMA_LOG_VERBOSITY) | -| `--log-prefix` | Enable prefix in log messages
(env: LLAMA_LOG_PREFIX) | -| `--log-timestamps` | Enable timestamps in log messages
(env: LLAMA_LOG_TIMESTAMPS) | +| `--log-prefix, --no-log-prefix` | Enable prefix in log messages
(env: LLAMA_ARG_LOG_PREFIX) | +| `--log-timestamps, --no-log-timestamps` | Enable timestamps in log messages
(env: LLAMA_ARG_LOG_TIMESTAMPS) | | `--spec-draft-type-k, -ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_K) | | `--spec-draft-type-v, -ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_V) | diff --git a/tools/server/README.md b/tools/server/README.md index 2ed7fe16ee2..11098af2883 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -72,7 +72,6 @@ For the full list of features, please refer to [server's changelog](https://gith | `-ctk, --cache-type-k TYPE` | KV cache data type for K
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K) | | `-ctv, --cache-type-v TYPE` | KV cache data type for V
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V) | | `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)
(env: LLAMA_ARG_DEFRAG_THOLD) | -| `--rpc SERVERS` | comma-separated list of RPC servers (host:port)
(env: LLAMA_ARG_RPC) | | `--mlock` | force system to keep model in RAM rather than swapping or compressing
(env: LLAMA_ARG_MLOCK) | | `--mmap, --no-mmap` | whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)
(env: LLAMA_ARG_MMAP) | | `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. (default: disabled)
(env: LLAMA_ARG_DIO) | @@ -111,8 +110,8 @@ For the full list of features, please refer to [server's changelog](https://gith | `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) | | `--offline` | Offline mode: forces use of cache, prevents network access
(env: LLAMA_OFFLINE) | | `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:
- 0: generic output
- 1: error
- 2: warning
- 3: info
- 4: debug
(default: 3)

(env: LLAMA_LOG_VERBOSITY) | -| `--log-prefix` | Enable prefix in log messages
(env: LLAMA_LOG_PREFIX) | -| `--log-timestamps` | Enable timestamps in log messages
(env: LLAMA_LOG_TIMESTAMPS) | +| `--log-prefix, --no-log-prefix` | Enable prefix in log messages
(env: LLAMA_ARG_LOG_PREFIX) | +| `--log-timestamps, --no-log-timestamps` | Enable timestamps in log messages
(env: LLAMA_ARG_LOG_TIMESTAMPS) | | `--spec-draft-type-k, -ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_K) | | `--spec-draft-type-v, -ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_V) | @@ -189,11 +188,15 @@ For the full list of features, please refer to [server's changelog](https://gith | `--reuse-port` | allow multiple sockets to bind to the same port (default: disabled)
(env: LLAMA_ARG_REUSE_PORT) | | `--path PATH` | path to serve static files from (default: )
(env: LLAMA_ARG_STATIC_PATH) | | `--api-prefix PREFIX` | prefix path the server serves from, without the trailing slash (default: )
(env: LLAMA_ARG_API_PREFIX) | -| `--ui-config JSON` / `--webui-config JSON` (deprecated) | JSON that provides default UI settings (overrides UI defaults)
(env: LLAMA_ARG_UI_CONFIG / LLAMA_ARG_WEBUI_CONFIG) | -| `--ui-config-file PATH` / `--webui-config-file PATH` (deprecated) | JSON file that provides default UI settings (overrides UI defaults)
(env: LLAMA_ARG_UI_CONFIG_FILE / LLAMA_ARG_WEBUI_CONFIG_FILE) | -| `--ui-mcp-proxy, --no-ui-mcp-proxy` / `--webui-mcp-proxy, --no-webui-mcp-proxy` (deprecated) | experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)
(env: LLAMA_ARG_UI_MCP_PROXY / LLAMA_ARG_WEBUI_MCP_PROXY) | +| `--webui-config JSON` | [DEPRECATED: use --ui-config] JSON that provides default WebUI settings (overrides WebUI defaults)
(env: LLAMA_ARG_WEBUI_CONFIG) | +| `--ui-config JSON` | JSON that provides default UI settings (overrides UI defaults)
(env: LLAMA_ARG_UI_CONFIG) | +| `--webui-config-file PATH` | [DEPRECATED: use --ui-config-file] JSON file that provides default WebUI settings (overrides WebUI defaults)
(env: LLAMA_ARG_WEBUI_CONFIG_FILE) | +| `--ui-config-file PATH` | JSON file that provides default UI settings (overrides UI defaults)
(env: LLAMA_ARG_UI_CONFIG_FILE) | +| `--webui-mcp-proxy, --no-webui-mcp-proxy` | [DEPRECATED: use --ui-mcp-proxy/--no-ui-mcp-proxy] experimental: whether to enable MCP CORS proxy
(env: LLAMA_ARG_WEBUI_MCP_PROXY) | +| `--ui-mcp-proxy, --no-ui-mcp-proxy` | experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)
(env: LLAMA_ARG_UI_MCP_PROXY) | | `--tools TOOL1,TOOL2,...` | experimental: whether to enable built-in tools for AI agents - do not enable in untrusted environments (default: no tools)
specify "all" to enable all tools
available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff, get_datetime
(env: LLAMA_ARG_TOOLS) | -| `--ui, --no-ui` / `--webui, --no-webui` (deprecated) | whether to enable the Web UI (default: enabled)
(env: LLAMA_ARG_UI / LLAMA_ARG_WEBUI) | +| `--webui, --no-webui` | [DEPRECATED: use --ui/--no-ui] whether to enable the Web UI
(env: LLAMA_ARG_WEBUI) | +| `--ui, --no-ui` | whether to enable the Web UI (default: enabled)
(env: LLAMA_ARG_UI) | | `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)
(env: LLAMA_ARG_EMBEDDINGS) | | `--rerank, --reranking` | enable reranking endpoint on server (default: disabled)
(env: LLAMA_ARG_RERANKING) | | `--api-key KEY` | API key to use for authentication, multiple keys can be provided as a comma-separated list (default: none)
(env: LLAMA_API_KEY) | @@ -248,7 +251,7 @@ For the full list of features, please refer to [server's changelog](https://gith | `--spec-draft-device, -devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)
use --list-devices to see a list of available devices | | `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)
(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) | | `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)
(env: LLAMA_ARG_SPEC_DRAFT_MODEL) | -| `--spec-type none,draft-simple,draft-eagle3,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-mod,ngram-cache` | comma-separated list of types of speculative decoding to use (default: none)

(env: LLAMA_ARG_SPEC_TYPE) | +| `--spec-type none,draft-simple,draft-eagle3,draft-mtp,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-mod,ngram-cache` | comma-separated list of types of speculative decoding to use (default: none)

(env: LLAMA_ARG_SPEC_TYPE) | | `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) | | `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) | | `--spec-ngram-mod-n-match N` | ngram-mod lookup length (default: 24) | From 2dff7ff8f90ce6daefd6adb097d58a4276e5dd2d Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 16 May 2026 18:35:48 +0800 Subject: [PATCH 28/28] conversion: fix type annotations --- conversion/qwen.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/conversion/qwen.py b/conversion/qwen.py index 6949ab979c8..4b86404262a 100644 --- a/conversion/qwen.py +++ b/conversion/qwen.py @@ -547,6 +547,8 @@ class _Qwen35MtpMixin: gguf_writer: gguf.GGUFWriter block_count: int tensor_map: gguf.TensorNameMap + no_mtp: bool + mtp_only: bool def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs)