Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
9b996f0
spec: support MTP
am17an May 11, 2026
80e1f3c
fix batch size
am17an May 11, 2026
8d16341
rename files
am17an May 11, 2026
5e1965d
cont : simplify (#7)
ggerganov May 11, 2026
89f6e0d
MTP: clean-up (#9)
am17an May 13, 2026
7ea1289
mtp -> draft-mtp
am17an May 13, 2026
9243e50
remove unused llama_arch
am17an May 13, 2026
23ae80a
add need_embd in speculative
am17an May 13, 2026
a5b3e98
llama: allow partial seq_rm for GDN models for speculative decoding
am17an Apr 25, 2026
3aa9ddc
fix pending state
am17an May 14, 2026
2ef737a
vulkan: add GDN partial rollback
am17an May 14, 2026
d7443da
meta: extend check to axis 1
am17an May 14, 2026
19be81c
metal: add GDN partial rollback
ggerganov May 14, 2026
d0759f0
delta_net_base: use ggml_pad instead of new_tensor
am17an May 14, 2026
78a78ae
review: add need_rs_seq
am17an May 14, 2026
611f422
review: rename part_bounded to n_rs
am17an May 14, 2026
df4cd32
review: deslop comments
am17an May 14, 2026
9674711
review: rename, add asserts
am17an May 14, 2026
7b54ac5
server : adjust checkpoint logic (#11)
ggerganov May 14, 2026
749a0b2
server-context: fix early exit
am17an May 14, 2026
d42d25d
spec : fix compatibility with n-gram and add TODOs (#13)
ggerganov May 15, 2026
cddbb7f
llama-memory: enable checkpointing with partial rollback
am17an May 15, 2026
6ef79f7
cont: add test-case for loading into a dirty ctx
am17an May 16, 2026
0f6f0d6
llama-memory-recurrent: clear rs_idx in clear
am17an May 16, 2026
37a479f
download: fix mtp path
am17an May 16, 2026
8e9a07d
llama-arch: fix enorm op
am17an May 16, 2026
5a818cd
docs: update docs
am17an May 16, 2026
2dff7ff
conversion: fix type annotations
am17an May 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,15 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa
// result of resolving a single model: records which companion files were
// reported next to it by the download step
struct handle_model_result {
    // multimodal projector - `mmproj.path` is filled in when `found_mmproj` is set
    bool found_mmproj{false};
    common_params_model mmproj;

    // MTP head - `mtp.path` is filled in when `found_mtp` is set
    bool found_mtp{false};
    common_params_model mtp;
};

static handle_model_result common_params_handle_model(struct common_params_model & model,
const std::string & bearer_token,
bool offline) {
bool offline,
bool search_mtp = false) {
handle_model_result result;
Comment on lines 345 to 349
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO @ngxson: make the function accept opts as an argument


if (!model.docker_repo.empty()) {
Expand All @@ -356,7 +360,7 @@ static handle_model_result common_params_handle_model(struct common_params_model
common_download_opts opts;
opts.bearer_token = bearer_token;
opts.offline = offline;
auto download_result = common_download_model(model, opts, true);
auto download_result = common_download_model(model, opts, true, search_mtp);

if (download_result.model_path.empty()) {
throw std::runtime_error("failed to download model from Hugging Face");
Expand All @@ -369,6 +373,11 @@ static handle_model_result common_params_handle_model(struct common_params_model
result.found_mmproj = true;
result.mmproj.path = download_result.mmproj_path;
}

if (!download_result.mtp_path.empty()) {
result.found_mtp = true;
result.mtp.path = download_result.mtp_path;
}
} else if (!model.url.empty()) {
if (model.path.empty()) {
auto f = string_split<std::string>(model.url, '#').front();
Expand Down Expand Up @@ -436,7 +445,11 @@ static bool parse_bool_value(const std::string & value) {
//

void common_params_handle_models(common_params & params, llama_example curr_ex) {
auto res = common_params_handle_model(params.model, params.hf_token, params.offline);
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
params.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();

auto res = common_params_handle_model(params.model, params.hf_token, params.offline, spec_type_draft_mtp);
if (params.no_mmproj) {
params.mmproj = {};
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
Expand All @@ -450,6 +463,14 @@ void common_params_handle_models(common_params & params, llama_example curr_ex)
break;
}
}
// when the DRAFT_MTP speculative type is enabled and no draft model was provided explicitly
// (no path, hf repo or url), fall back to the MTP head discovered alongside the -hf model
if (spec_type_draft_mtp && res.found_mtp &&
params.speculative.draft.mparams.path.empty() &&
params.speculative.draft.mparams.hf_repo.empty() &&
params.speculative.draft.mparams.url.empty()) {
params.speculative.draft.mparams.path = res.mtp.path;
}
common_params_handle_model(params.speculative.draft.mparams, params.hf_token, params.offline);
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
}
Expand Down Expand Up @@ -3608,8 +3629,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("comma-separated list of types of speculative decoding to use (default: %s)\n",
common_speculative_type_name_str(params.speculative.types).c_str()),
[](common_params & params, const std::string & value) {
const auto enabled_types = string_split<std::string>(value, ',');
params.speculative.types = common_speculative_types_from_names(enabled_types);
const auto types_str = string_split<std::string>(value, ',');
auto types = common_speculative_types_from_names(types_str);
params.speculative.types.insert(params.speculative.types.end(), types.begin(), types.end());
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_TYPE"));
add_opt(common_arg(
Expand Down Expand Up @@ -4098,7 +4120,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--spec-default"},
string_format("enable default speculative decoding config"),
[](common_params & params) {
params.speculative.types = { COMMON_SPECULATIVE_TYPE_NGRAM_MOD };
params.speculative.types.push_back(COMMON_SPECULATIVE_TYPE_NGRAM_MOD);
params.speculative.ngram_mod.n_match = 24;
params.speculative.ngram_mod.n_min = 48;
params.speculative.ngram_mod.n_max = 64;
Expand Down
56 changes: 56 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "log.h"
#include "llama.h"
#include "sampling.h"
#include "speculative.h"
#include "unicode.h"

#include <algorithm>
Expand Down Expand Up @@ -1247,6 +1248,29 @@ common_init_result::common_init_result(common_params & params) :
cparams.n_samplers = pimpl->samplers_seq_config.size();
}

// [TAG_RS_STATE_ROLLBACK_SUPPORT]
// TODO: ngram speculative methods need checkpointing on top of partial RS rollback,
//       which is not supported yet - so the partial rollback gets turned off
if (cparams.n_rs_seq > 0 && (llama_model_is_recurrent(model) || llama_model_is_hybrid(model))) {
    // the first configured type other than NONE / DRAFT_MTP disables the rollback
    for (const auto type : params.speculative.types) {
        const bool keeps_rollback =
            type == COMMON_SPECULATIVE_TYPE_NONE ||
            type == COMMON_SPECULATIVE_TYPE_DRAFT_MTP;
        if (keeps_rollback) {
            continue;
        }

        cparams.n_rs_seq = 0;

        LOG_WRN("%s: recurrent state rollback is not compatible with '%s' - disabling rollback support\n", __func__,
                common_speculative_type_to_str(type).c_str());

        break;
    }
}

llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
Expand Down Expand Up @@ -1435,6 +1459,12 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
goto done;
}

if (llama_n_rs_seq(ctx) > 0) {
LOG_INF("%s: the context supports bounded partial sequence removal\n", __func__);
res = COMMON_CONTEXT_SEQ_RM_TYPE_RS;
goto done;
}

// try to remove the last tokens
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
LOG_TRC("%s: the context does not support partial sequence removal\n", __func__);
Expand All @@ -1449,6 +1479,23 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
return res;
}

// forwards to llama_memory_seq_rm on the memory of `ctx`
// aborts execution (GGML_ABORT) when the removal is reported as failed
void common_context_seq_rm(llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    const bool removed = llama_memory_seq_rm(llama_get_memory(ctx), seq_id, p0, p1);
    if (removed) {
        return;
    }

    const std::string msg = string_format("failed to remove sequence %d with p0=%d, p1=%d\n", seq_id, p0, p1);
    GGML_ABORT("%s", msg.c_str());
}

// forwards to llama_memory_seq_cp on the memory of `ctx` (src -> dst, positions p0..p1)
void common_context_seq_cp(llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    llama_memory_seq_cp(llama_get_memory(ctx), seq_id_src, seq_id_dst, p0, p1);
}

// forwards to llama_memory_seq_add on the memory of `ctx` (positions p0..p1, offset `delta`)
void common_context_seq_add(llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
    llama_memory_seq_add(llama_get_memory(ctx), seq_id, p0, p1, delta);
}

void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
std::vector<llama_adapter_lora *> loras;
std::vector<float> scales;
Expand Down Expand Up @@ -1505,6 +1552,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &

cparams.n_ctx = params.n_ctx;
cparams.n_seq_max = params.n_parallel;
cparams.n_rs_seq = params.speculative.need_n_rs_seq();
cparams.n_batch = params.n_batch;
cparams.n_ubatch = params.n_ubatch;
cparams.n_threads = params.cpuparams.n_threads;
Expand Down Expand Up @@ -2074,3 +2122,11 @@ void common_prompt_checkpoint::load_dft(
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_dft.size(), n);
}
}

// empties the `data_tgt` buffer of this checkpoint
void common_prompt_checkpoint::clear_tgt() {
data_tgt.clear();
}

// empties the `data_dft` buffer of this checkpoint
void common_prompt_checkpoint::clear_dft() {
data_dft.clear();
}
26 changes: 22 additions & 4 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <string_view>
#include <vector>
#include <map>
#include <algorithm>

#if defined(_WIN32) && !defined(_WIN32_WINNT)
#define _WIN32_WINNT 0x0A00
Expand Down Expand Up @@ -159,6 +160,7 @@ enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, // standalone draft model speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, // Eagle3 speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_MTP, // Multi-token prediction
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding based on n-grams
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
Expand Down Expand Up @@ -301,7 +303,7 @@ struct common_params_speculative_draft {
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding

float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
float p_min = 0.75f; // minimum speculative decoding probability (greedy) // TODO: change default to 0.0f

common_params_model mparams;

Expand Down Expand Up @@ -355,6 +357,14 @@ struct common_params_speculative {
bool has_dft() const {
return !draft.mparams.path.empty() || !draft.mparams.hf_repo.empty();
}

uint32_t need_n_rs_seq() const {
bool needs_rs_seq = std::any_of(types.begin(), types.end(), [&](auto t) {
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP;
});

return needs_rs_seq ? draft.n_max : 0u;
}
};

struct common_params_vocoder {
Expand Down Expand Up @@ -884,15 +894,20 @@ std::string common_get_model_endpoint();
//

enum common_context_seq_rm_type {
COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. no memory module)
COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences
COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only
COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. no memory module)
COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences
COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only
COMMON_CONTEXT_SEQ_RM_TYPE_RS = 3, // can seq_rm partial sequences, bounded by n_rs_seq
};

// check if the llama_context can remove sequences
// note: clears the memory of the context
common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx);

// thin wrappers around the llama_memory_seq_* calls on the memory of the context
// note: only common_context_seq_rm checks the result - it aborts execution on failure
void common_context_seq_rm (llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1);
void common_context_seq_add(llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta);
void common_context_seq_cp (llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1);

//
// Batch utils
Expand Down Expand Up @@ -1074,4 +1089,7 @@ struct common_prompt_checkpoint {
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) const;

// empty the stored `data_tgt` / `data_dft` buffers
void clear_tgt();
void clear_dft();
};
55 changes: 42 additions & 13 deletions common/download.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,11 @@ static hf_cache::hf_files get_split_files(const hf_cache::hf_files & files,
return result;
}

static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
const std::string & model) {
// pick the best sibling GGUF whose filename contains `keyword` (e.g. "mmproj" / "mtp"),
// preferring deeper shared directory prefix with the model, then closest quantization
static hf_cache::hf_file find_best_sibling(const hf_cache::hf_files & files,
const std::string & model,
const std::string & keyword) {
hf_cache::hf_file best;
size_t best_depth = 0;
int best_diff = 0;
Expand All @@ -579,20 +582,20 @@ static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,

for (const auto & f : files) {
if (!string_ends_with(f.path, ".gguf") ||
f.path.find("mmproj") == std::string::npos) {
f.path.find(keyword) == std::string::npos) {
continue;
}

auto mmproj_parts = string_split<std::string>(f.path, '/');
auto mmproj_dir = mmproj_parts.end() - 1;
auto sib_parts = string_split<std::string>(f.path, '/');
auto sib_dir = sib_parts.end() - 1;

auto [_, dir] = std::mismatch(model_parts.begin(), model_dir,
mmproj_parts.begin(), mmproj_dir);
if (dir != mmproj_dir) {
sib_parts.begin(), sib_dir);
if (dir != sib_dir) {
continue;
}

size_t depth = dir - mmproj_parts.begin();
size_t depth = dir - sib_parts.begin();
auto bits = extract_quant_bits(f.path);
auto diff = std::abs(bits - model_bits);

Expand All @@ -606,6 +609,16 @@ static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
return best;
}

// pick the best sibling GGUF of `model` whose path contains "mmproj"
// (selection is delegated to find_best_sibling)
static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
const std::string & model) {
return find_best_sibling(files, model, "mmproj");
}

// pick the best sibling GGUF of `model` whose path contains "mtp-"
// (selection is delegated to find_best_sibling; note the trailing dash in the keyword)
static hf_cache::hf_file find_best_mtp(const hf_cache::hf_files & files,
const std::string & model) {
return find_best_sibling(files, model, "mtp-");
}

static bool gguf_filename_is_model(const std::string & filepath) {
if (!string_ends_with(filepath, ".gguf")) {
return false;
Expand All @@ -617,7 +630,8 @@ static bool gguf_filename_is_model(const std::string & filepath) {
}

return filename.find("mmproj") == std::string::npos &&
filename.find("imatrix") == std::string::npos;
filename.find("imatrix") == std::string::npos &&
filename.find("mtp-") == std::string::npos;
}

static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
Expand Down Expand Up @@ -673,11 +687,13 @@ struct hf_plan {
hf_cache::hf_file primary;
hf_cache::hf_files model_files;
hf_cache::hf_file mmproj;
hf_cache::hf_file mtp;
};

static hf_plan get_hf_plan(const common_params_model & model,
const common_download_opts & opts,
bool download_mmproj) {
bool download_mmproj,
bool download_mtp) {
hf_plan plan;
hf_cache::hf_files all;

Expand Down Expand Up @@ -723,6 +739,10 @@ static hf_plan get_hf_plan(const common_params_model & model,
plan.mmproj = find_best_mmproj(all, primary.path);
}

if (download_mtp) {
plan.mtp = find_best_mtp(all, primary.path);
}

return plan;
}

Expand Down Expand Up @@ -756,21 +776,25 @@ static std::vector<download_task> get_url_tasks(const common_params_model & mode

common_download_model_result common_download_model(const common_params_model & model,
const common_download_opts & opts,
bool download_mmproj) {
bool download_mmproj,
bool download_mtp) {
common_download_model_result result;
std::vector<download_task> tasks;
hf_plan hf;

bool is_hf = !model.hf_repo.empty();

if (is_hf) {
hf = get_hf_plan(model, opts, download_mmproj);
hf = get_hf_plan(model, opts, download_mmproj, download_mtp);
for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path});
}
if (!hf.mmproj.path.empty()) {
tasks.push_back({hf.mmproj.url, hf.mmproj.local_path});
}
if (!hf.mtp.path.empty()) {
tasks.push_back({hf.mtp.url, hf.mtp.local_path});
}
} else if (!model.url.empty()) {
tasks = get_url_tasks(model);
} else {
Expand Down Expand Up @@ -807,6 +831,10 @@ common_download_model_result common_download_model(const common_params_model &
if (!hf.mmproj.path.empty()) {
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
}

if (!hf.mtp.path.empty()) {
result.mtp_path = hf_cache::finalize_file(hf.mtp);
}
} else {
result.model_path = model.path;
}
Expand Down Expand Up @@ -946,7 +974,8 @@ std::vector<common_cached_model_info> common_list_cached_models() {
for (const auto & f : files) {
auto split = get_gguf_split_info(f.path);
if (split.index != 1 || split.tag.empty() ||
split.prefix.find("mmproj") != std::string::npos) {
split.prefix.find("mmproj") != std::string::npos ||
split.prefix.find("mtp-") != std::string::npos) {
continue;
}
if (seen.insert(f.repo_id + ":" + split.tag).second) {
Expand Down
Loading
Loading