Changes from all commits

24 commits
2c9a408
spec : refactor
ggerganov May 7, 2026
befc7ef
spec : drop support for incompatible vocabs
ggerganov May 7, 2026
4550f0f
spec : update common_speculative_init()
ggerganov May 7, 2026
77269ad
cont : pass seq_id
ggerganov May 7, 2026
8a50f6f
cont : dedup ctx_seq_rm_type
ggerganov May 7, 2026
c97dc36
server : sketch the ctx_dft decode loop
ggerganov May 7, 2026
11fd5e7
server : draft prompt cache and checkpoints
ggerganov May 7, 2026
1afee5b
server : improve ctx names
ggerganov May 7, 2026
de35b12
server, spec : transition to unified spec context
ggerganov May 7, 2026
08c8012
cont : sync main and drft contexts
ggerganov May 7, 2026
c7facb0
cont : async drft eval when possible
ggerganov May 7, 2026
0239f4c
cont : handle non-ckpt models
ggerganov May 7, 2026
ae6703f
cont : pass correct n_past for drafting
ggerganov May 7, 2026
7e118cd
cont : process images through the draft context
ggerganov May 7, 2026
8be14e4
spec : handle draft running out of context
ggerganov May 8, 2026
6a4b05a
server : fix mtmd draft processing
ggerganov May 8, 2026
12c7cfb
server : fix URL for draft model
ggerganov May 8, 2026
233d1ae
server : add comment
ggerganov May 8, 2026
3b1a8df
server : clean-up + dry
ggerganov May 8, 2026
e5b1401
speculative-simple : update
ggerganov May 8, 2026
161eae0
spec : fix n_past type
ggerganov May 8, 2026
1dbc054
server : fix slot ctx_drft ptr
ggerganov May 8, 2026
778f9e2
tools : update readme
ggerganov May 8, 2026
efa2f8e
naming : improve consistency
ggerganov May 8, 2026
18 changes: 0 additions & 18 deletions common/arg.cpp
@@ -622,10 +622,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         for (auto & seq_breaker : params.sampling.dry_sequence_breakers) {
             string_process_escapes(seq_breaker);
         }
-        for (auto & pair : params.speculative.draft.replacements) {
-            string_process_escapes(pair.first);
-            string_process_escapes(pair.second);
-        }
     }
 
     if (!params.kv_overrides.empty()) {
@@ -3518,13 +3514,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.speculative.draft.p_min = std::stof(value);
         }
     ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_P_MIN"));
-    add_opt(common_arg(
-        {"--spec-draft-ctx-size", "-cd", "--ctx-size-draft"}, "N",
-        string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.draft.n_ctx),
-        [](common_params & params, int value) {
-            params.speculative.draft.n_ctx = value;
-        }
-    ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_CTX_SIZE"));
     add_opt(common_arg(
         {"--spec-draft-device", "-devd", "--device-draft"}, "<dev1,dev2,..>",
         "comma-separated list of devices to use for offloading the draft model (none = don't offload)\n"
@@ -3560,13 +3549,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.speculative.draft.mparams.path = value;
         }
     ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL"));
-    add_opt(common_arg(
-        {"--spec-draft-replace", "--spec-replace"}, "TARGET", "DRAFT",
-        "translate the string in TARGET into DRAFT if the draft model and main model are not compatible",
-        [](common_params & params, const std::string & tgt, const std::string & dft) {
-            params.speculative.draft.replacements.push_back({ tgt, dft });
-        }
-    ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
     add_opt(common_arg(
         {"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
         string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
101 changes: 100 additions & 1 deletion common/common.cpp
@@ -1422,7 +1422,7 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
 
     // try to remove the last tokens
     if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
-        LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
+        LOG_WRN("%s: the context does not support partial sequence removal\n", __func__);
         res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
         goto done;
     }
@@ -1960,3 +1960,102 @@ bool common_prompt_batch_decode(
 
     return true;
 }
+
+size_t common_prompt_checkpoint::size() const {
+    return data_tgt.size() + data_dft.size();
+}
+
+bool common_prompt_checkpoint::empty() const {
+    return data_tgt.empty();
+}
+
+void common_prompt_checkpoint::clear() {
+    n_tokens = 0;
+
+    pos_min = 0;
+    pos_max = 0;
+
+    data_tgt.clear();
+    data_dft.clear();
+}
+
+void common_prompt_checkpoint::update_pos(
+        int64_t n_tokens,
+        llama_pos pos_min,
+        llama_pos pos_max) {
+    this->n_tokens = n_tokens;
+    this->pos_min = pos_min;
+    this->pos_max = pos_max;
+}
+
+void common_prompt_checkpoint::update_tgt(
+        llama_context * ctx,
+        llama_seq_id seq_id,
+        llama_state_seq_flags flags) {
+    if (ctx == nullptr) {
+        return;
+    }
+
+    const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags);
+
+    data_tgt.resize(ckpt_size);
+
+    const size_t n = llama_state_seq_get_data_ext(ctx, data_tgt.data(), ckpt_size, seq_id, flags);
+    if (n != ckpt_size) {
+        GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n);
+    }
+}
+
+void common_prompt_checkpoint::update_dft(
+        llama_context * ctx,
+        llama_seq_id seq_id,
+        llama_state_seq_flags flags) {
+    if (ctx == nullptr) {
+        return;
+    }
+
+    const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags);
+
+    data_dft.resize(ckpt_size);
+
+    const size_t n = llama_state_seq_get_data_ext(ctx, data_dft.data(), ckpt_size, seq_id, flags);
+    if (n != ckpt_size) {
+        GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n);
+    }
+}
+
+void common_prompt_checkpoint::load_tgt(
+        llama_context * ctx,
+        llama_seq_id seq_id,
+        llama_state_seq_flags flags) const {
+    if (ctx == nullptr) {
+        return;
+    }
+
+    if (data_tgt.empty()) {
+        return;
+    }
+
+    const size_t n = llama_state_seq_set_data_ext(ctx, data_tgt.data(), data_tgt.size(), seq_id, flags);
+    if (n != data_tgt.size()) {
+        GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_tgt.size(), n);
+    }
+}
+
+void common_prompt_checkpoint::load_dft(
+        llama_context * ctx,
+        llama_seq_id seq_id,
+        llama_state_seq_flags flags) const {
+    if (ctx == nullptr) {
+        return;
+    }
+
+    if (data_dft.empty()) {
+        return;
+    }
+
+    const size_t n = llama_state_seq_set_data_ext(ctx, data_dft.data(), data_dft.size(), seq_id, flags);
+    if (n != data_dft.size()) {
+        GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_dft.size(), n);
+    }
+}
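
Editor's note: the checkpoint API above serializes per-sequence state for the target and (optionally) the draft context, so a server slot can roll both back in lockstep. Below is a minimal usage sketch, not code from this PR: the names `ctx_tgt`, `ctx_dft`, and `seq_id`, the `n_tokens = pos_max + 1` bookkeeping, and passing `0` for the state flags are all assumptions standing in for the server's own slot logic.

```cpp
// Hedged sketch: snapshot and restore a sequence around a speculative step.
// Assumes flags = 0 requests the default (full) state serialization.
#include "common.h"
#include "llama.h"

static void checkpoint_roundtrip(llama_context * ctx_tgt, llama_context * ctx_dft, llama_seq_id seq_id) {
    common_prompt_checkpoint ckpt;
    ckpt.clear();

    // record the position range currently held in the target KV cache
    llama_memory_t mem = llama_get_memory(ctx_tgt);
    const llama_pos pos_min = llama_memory_seq_pos_min(mem, seq_id);
    const llama_pos pos_max = llama_memory_seq_pos_max(mem, seq_id);

    ckpt.update_pos(pos_max + 1, pos_min, pos_max); // n_tokens = pos_max + 1 is an assumption
    ckpt.update_tgt(ctx_tgt, seq_id, 0);
    ckpt.update_dft(ctx_dft, seq_id, 0); // no-op when ctx_dft == nullptr

    // ... speculative decoding may now dirty both KV caches ...

    // roll both contexts back to the snapshot
    if (!ckpt.empty()) {
        ckpt.load_tgt(ctx_tgt, seq_id, 0);
        ckpt.load_dft(ctx_dft, seq_id, 0); // no-op when no draft state was saved
    }
}
```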
51 changes: 46 additions & 5 deletions common/common.h
@@ -307,11 +307,9 @@ struct common_params_speculative_draft {
 
     common_params_model mparams;
 
-    llama_model * model = nullptr; // a llama_model that can be shared by multiple speculative contexts
+    llama_context * ctx_tgt = nullptr;
+    llama_context * ctx_dft = nullptr;
 
-    llama_context_params cparams; // these are the parameters for the draft llama_context
-
-    int32_t n_ctx = 0; // draft context size
     int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
 
     ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
@@ -322,7 +320,6 @@ struct common_params_speculative_draft {
 
     std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
 
-    std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
     std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
 };
 
@@ -1026,3 +1023,47 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std
 
 // "adamw" or "sgd" (case insensitive)
 enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);
+
+//
+// prompt utils
+//
+
+struct common_prompt_checkpoint {
+    int64_t n_tokens;
+
+    llama_pos pos_min;
+    llama_pos pos_max;
+
+    std::vector<uint8_t> data_tgt;
+    std::vector<uint8_t> data_dft;
+
+    size_t size() const;
+
+    bool empty() const;
+    void clear();
+
+    void update_pos(
+            int64_t n_tokens,
+            llama_pos pos_min,
+            llama_pos pos_max);
+
+    void update_tgt(
+            llama_context * ctx,
+            llama_seq_id seq_id,
+            llama_state_seq_flags flags);
+
+    void update_dft(
+            llama_context * ctx,
+            llama_seq_id seq_id,
+            llama_state_seq_flags flags);
+
+    void load_tgt(
+            llama_context * ctx,
+            llama_seq_id seq_id,
+            llama_state_seq_flags flags) const;
+
+    void load_dft(
+            llama_context * ctx,
+            llama_seq_id seq_id,
+            llama_state_seq_flags flags) const;
+};
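
Editor's note: one commit in this series is "spec : handle draft running out of context", and the struct above gives a natural shape for that kind of recovery. The sketch below is an illustration under assumptions, not the server's actual logic: it restores the newest checkpoint whose `pos_max` does not extend past the position being rolled back to, and again assumes `0` is an acceptable default for the state flags.

```cpp
// Hedged sketch: pick the newest usable checkpoint at or before n_past and
// restore both contexts from it, e.g. after the draft runs out of context.
#include <vector>
#include "common.h"

static bool rollback_to_checkpoint(
        std::vector<common_prompt_checkpoint> & checkpoints, // ordered oldest to newest
        llama_context * ctx_tgt,
        llama_context * ctx_dft,
        llama_seq_id    seq_id,
        llama_pos       n_past) {
    for (auto it = checkpoints.rbegin(); it != checkpoints.rend(); ++it) {
        if (it->empty() || it->pos_max > n_past) {
            continue; // no target state saved, or checkpoint lies past the rollback point
        }
        it->load_tgt(ctx_tgt, seq_id, 0);
        it->load_dft(ctx_dft, seq_id, 0); // no-op if no draft state was saved
        return true;
    }
    return false; // nothing usable: the caller has to re-decode the prompt
}
```

A real implementation would presumably also bound the total `size()` of retained checkpoints to keep host memory in check.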