
llama : extend embeddings API #22728

Draft

ggerganov wants to merge 1 commit into master from gg/llama-extract-embeddings

Conversation

@ggerganov
Member

Overview

Preparing some base functionality needed for extracting embeddings from different stages of inference. This is needed to support speculative decoding methods such as Eagle3, MTP, etc. A hypothetical sketch of how these pieces might surface in the API follows the list below.

  • Layer input embeddings extraction
  • Token embedding tensor replacement
  • TBD
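
For illustration, here is a minimal sketch of how the two pieces above might surface in the C API. All names and signatures are assumptions (llama_model_set_tok_embd is mentioned later in the discussion; the extraction side is purely speculative):

// hypothetical sketch - none of these declarations are final API

// token embedding tensor replacement: swap the model's tok_embd tensor
// (the signature is an assumption)
void llama_model_set_tok_embd(struct llama_model * model, struct ggml_tensor * tok_embd);

// layer input embeddings extraction: ask the context to retain the input
// embeddings of layer il during decode and read them back afterwards
// (names are illustrative)
void    llama_set_embeddings_layer(struct llama_context * ctx, int32_t il, bool enable);
float * llama_get_embeddings_layer(struct llama_context * ctx, int32_t il);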

Additional information

TBD - still figuring out what's needed

Requirements

@github-actions bot added the model (Model specific) label May 5, 2026
@am17an
Contributor

am17an commented May 6, 2026

Let me add my comment here since that PR is getting a bit crowded.

IMO the only blocker to running this inside the same llama graph was that the KV cache becomes awkward to handle. Perhaps we can restructure the KV cache so that there is an auxiliary cache for these "speculators"; then everything would more or less already work from a sync perspective. Or maybe I am missing something.

@am17an
Contributor

am17an commented May 6, 2026

It looks like for Gemma4 MTP, the MTP head actually attends to the target's KV cache; something to keep in mind.

@ggerganov
Member Author

run this inside the same llama graph

I don't think it is feasible to have the main and MTP graphs combined. Having the MTP graph (and any other speculative decoding graph) in a separate context has many advantages:

  • It has a separate backend scheduler
  • We have finer control on which devices to place the drafter
  • We can work with draft models that are separate from the main model
  • Multi-sequence drafting is much easier to conceptualize
  • etc.

There are many different variants of speculative decoding and more will appear in the future. We cannot stuff all this logic inside the llama_context. I think the current common/speculative foundation is overall good. We have to extend it to manage the prompt embeddings. And for that we first need a mechanism to extract them.

For the Gemma4 MTP, I would try to create a memory-less llama_context and assign the llama_memory from the target context to it so it can use it directly. After drafting, we would need to wipe the drafted tokens from the memory with seq_rm to restore the memory state for the target model.
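
A rough sketch of this flow, assuming a hypothetical llama_context_attach_memory call (attaching an external llama_memory to a context has no public API yet; llama_get_memory and llama_memory_seq_rm are the existing C API):

// share the target's memory with a memory-less draft context so the draft
// graph can attend to the target's KV cache directly
llama_memory_t mem_tgt = llama_get_memory(ctx_tgt);
llama_context_attach_memory(ctx_dft, mem_tgt); // hypothetical call

// ... draft n_drafted tokens with ctx_dft starting at position n_past ...

// restore the memory state for the target model by wiping the drafted tokens
llama_memory_seq_rm(mem_tgt, seq_id, n_past, n_past + n_drafted);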

@am17an
Contributor

am17an commented May 6, 2026

Okay, I agree that makes sense; stuffing everything into llama_context is not good. So what we need, if I understand this correctly, is some sort of checkpoint (which is persistent, to make it work for prompt_cache use-cases) after each ubatch that the speculator can hook onto in order to run a prompt batch itself.

@ggerganov
Member Author

ggerganov commented May 6, 2026

The current speculative API basically assumes that the only input for the speculator is the prompt tokens:

// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_draft(
                     common_speculative * spec,
        const common_params_speculative & params,
                     const llama_tokens & prompt,         <--- here
                            llama_token   id_last);

For draft-model based speculative decoding (and also for Gemma4 MTP) this is true. But for standard MTP, Eagle3, DFlash - this is not true. We need to also pass the target-model embeddings (pre-output norm for MTP and intermediate layer-input embeddings for Eagle3).

I think we probably need to wrap this in a helper struct, something like this:

struct common_tokens {
    llama_tokens prompt;

    // optional embeddings collection from the target model
    // TODO: some generic way to encompass various embeddings. not sure exactly how
    // note: can become large in some cases
    llama_embeddings embd;
};

llama_tokens common_speculative_draft(
                     common_speculative * spec,
        const common_params_speculative & params,
                    const common_tokens & prompt,
                            llama_token   id_last);

This way, similar to the draft-model based implementation that we currently have, we can do the prefix reuse and prompt reprocessing of the MTP context.

The server logic would need to be updated to manage common_tokens instead of just llama_tokens and extract the embeddings during prompt processing when needed.
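
As a hedged sketch, that server-side flow might end up looking roughly like this (spec_needs_embeddings and extract_target_embeddings are hypothetical placeholders for the extraction mechanism this PR prepares):

// hypothetical server-side usage of the proposed common_tokens wrapper
common_tokens ctokens;
ctokens.prompt = slot.prompt_tokens;

// extract the target embeddings during prompt processing when the method needs them
if (spec_needs_embeddings(params_spec)) {                              // hypothetical
    ctokens.embd = extract_target_embeddings(ctx_tgt, ctokens.prompt); // hypothetical
}

llama_tokens draft = common_speculative_draft(spec, params_spec, ctokens, id_last);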

@am17an
Contributor

am17an commented May 6, 2026

@ggerganov this makes sense to me, I can move all the stuff out of llama-context and into the common speculative class. This also means that the embeddings are shared and not copied if they are on the same backend, right?

@am17an
Contributor

am17an commented May 6, 2026

Actually, it doesn't seem like this would address this particular problem of sharing the tensor between two llama contexts, but at least we can do a device-to-device copy.

@am17an
Contributor

am17an commented May 6, 2026

    // optional embeddings collection from the target model
    // TODO: some generic way to encompass various embeddings. not sure exactly how
    // note: can become large in some cases
    // llama_embeddings embd;

It would be nice if the speculator could mark the layer hidden states as required using some API. For MTP it would be the last hidden layer pre-norm, and for Eagle3 it would be a bunch of other layers as well. Then it can get exactly those embeddings back in common_tokens.
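
Something along these lines, perhaps; a sketch assuming an illustrative registration call (the function name, struct, and layer indices are all made up for the example):

// hypothetical API: the speculator declares up front which hidden states it needs
struct common_embd_req {
    int32_t il; // layer index whose input embeddings to extract,
                // or -1 for the last hidden state pre-output-norm
};

void common_speculative_require_embd(
                common_speculative * spec,
        const std::vector<common_embd_req> & reqs); // hypothetical

// MTP: just the last hidden state before the output norm
common_speculative_require_embd(spec, { { -1 } });

// Eagle3: layer-input embeddings from several intermediate layers
// (in practice the layer indices would come from the draft GGUF)
common_speculative_require_embd(spec, { { 2 }, { 16 }, { 29 } });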

@ggerganov
Member Author

One issue that I realized just now is that keeping all of the embeddings for the entire prompt in host memory would not be a good idea: slow, and lots of memory. We have to pass them in batches (not ubatches) to the speculative context immediately. To do that, I think we first need to refactor the speculative contexts to have a single shared multi-sequence llama context (currently we have a llama context per slot). Doing so would allow us to have a single speculative llama context that "mirrors" the main server llama context. We would then decode the same batches, make the same checkpoints, etc. We would also be able to draft for multiple sequences in parallel.

I think we have to do this initial refactor first using the existing draft-model based speculative decoding.
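
As a sketch, the mirroring could look roughly like this (assuming ctx_spec is created with the same n_seq_max as the main context; batch construction and checkpointing are elided):

// one shared multi-sequence speculative context instead of one per slot
llama_context * ctx_main; // main server context
llama_context * ctx_spec; // single speculative context, same n_seq_max

// the server decodes a batch on the main context ...
llama_decode(ctx_main, batch);

// ... and mirrors the same batch into the speculative context right away,
// handing over the per-batch target embeddings instead of accumulating the
// whole prompt's embeddings in host memory
llama_decode(ctx_spec, batch);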

@ruixiang63

ruixiang63 commented May 6, 2026

Thanks for the proposal and discussion here. I have a few observations on how this generalizes to EAGLE3/DFlash vs. MTP:

  1. llama_model_set_tok_embd for token embedding tensor replacement cleanly handles MTP's case, where the extracted hidden state is fed in as the token embedding. But EAGLE3/DFlash decoders take two inputs: the regular token embedding for the current draft token, and a separate embeddings tensor that comes from an fc-compressed concatenation of multiple target hidden states (via the EAGLE3/DFlash encoder).

  2. The current EAGLE3/DFlash PR ends up with three llama contexts: ctx_tgt, ctx_dft_enc, and ctx_dft_dec, because the draft encoder and decoder are two distinct computation phases with different lifecycles (the draft encoder runs once per target verification to compress hidden states; for EAGLE3 the decoder then runs autoregressively N times to generate draft tokens, while DFlash runs it just once). MTP avoids this entirely (single draft context) because its hidden-as-embed mapping is 1:1 and it doesn't need a draft encoder to compress extracted target hidden states.

    To do that, I think we first need to refactor the speculative contexts to have a single shared multi-sequence llama context (currently we have a llama context per slot). Doing so would allow us to have a single speculative llama context that "mirrors" the main server llama context. We would then decode the same batches, make the same checkpoints, etc. We would also be able to draft for multiple sequences in parallel.

    Not sure if a single speculative llama context can work for EAGLE3/DFlash due to the different lifecycles of the draft encoder and decoder. And there's no way to express "the decoder takes both a token input and an auxiliary embedding input" right now; it seems the struct common_tokens wrapper is designed for this, right?

  3. One issue that I realized just now is that keeping all of the embeddings for the entire prompt in host memory would not be a good idea: slow, and lots of memory.

    +1 on avoiding host round-trips. Whatever the API ends up looking like, I agree the data path between target output and drafter input should be a device-to-device copy or a shared buffer (for EAGLE3/DFlash we need to avoid the draft encoder --> draft decoder copy as well).

  4. For MTP it would be the last hidden layer pre-norm, and for Eagle3 it would be a bunch of other layers as well. Then it can get exactly those embeddings back in common_tokens

    Right, this is another difference between MTP and DFlash/EAGLE3. EAGLE3 seems to extract features from only 3 layers of the target model; DFlash extracts 5 in most cases, but I saw that in their new model they extract 6 layers. So it would be great to make the extraction API flexible. And the extracted layer indices will come from the draft GGUF model directly.

So my feeling about the path forward is:

  1. this PR handles the target hidden states extraction side
  2. we need an auxiliary-embd input API to satisfy the EAGLE3 and DFlash needs (token input plus auxiliary embedding input)
  3. for EAGLE3 and DFlash: collapsing the draft encoder + decoder into one draft context seems not possible to me
  4. a backend-level shared/copied embedding handoff to avoid the host round-trip
  5. as @ggerganov mentioned before, "be able to draft for multiple sequences in parallel" and "handling batches (not ubatches) to avoid overwrite" are also important for this API design

@ggerganov
Member Author

Actually, it doesn't seem like this would address this particular problem of sharing the tensor between two llama contexts, but at least we can do a device-to-device copy.

Yes, I am not sure about that yet. Worst case, we still have to go through host memory, but at least we would only have to transfer the target embeddings for the current batch. Though, as you mention, we can likely do better.

The current EAGLE3/DFlash PR ends up with three llama contexts: ctx_tgt, ctx_dft_enc, and ctx_dft_dec, because the draft encoder and decoder are two distinct computation phases with different lifecycles (the draft encoder runs once per target verification to compress hidden states; for EAGLE3 the decoder then runs autoregressively N times to generate draft tokens, while DFlash runs it just once). MTP avoids this entirely (single draft context) because its hidden-as-embed mapping is 1:1 and it doesn't need a draft encoder to compress extracted target hidden states.

The new Gemma4 MTP is similar in that regard - it has an "encoder" for compressing the target states. In general, I think we don't need separate ctx_dft_enc and ctx_dft_dec. We can have a single enc-dec llama_context and call both llama_encode and llama_decode on it.
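
As a sketch of the single enc-dec draft context (batch construction is elided; for Eagle3 the decode loop runs N times, for DFlash just once):

// once per target verification: compress the target hidden states
// (batch_embd would carry the extracted target embeddings)
llama_encode(ctx_dft, batch_embd);

// then draft autoregressively from the compressed states
for (int i = 0; i < n_draft; ++i) {
    llama_decode(ctx_dft, batch_tok); // batch_tok carries the last drafted token
    // ... sample the next draft token from the logits and feed it back ...
}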

Not sure if a single speculative llama context can work for eagle3/DFlash due to different lifecycles of draft encoder and decoder.

Why would they have different lifecycles?

@ggerganov
Member Author

I think I'll focus and try to do the refactor of the speculative context to use a single llama_context that mirrors the main server llama_context. For that part, we don't need the new embeddings API yet - we can validate that the logic is sound by using the draft-model based speculative decoding. If it works out, then we can proceed with extending the embeddings API and hooking the target embeddings as necessary in order to support the rest of the speculative decoding methods.

@ruixiang63

In general, I think we don't need separate ctx_dft_enc and ctx_dft_dec. We can have a single enc-dec llama_context and call both llama_encode and llama_decode on it.

Got it, that resolves the three-context concern cleanly.

Why would they have different lifecycles?

I meant that in EAGLE3 the encoder fires once per target verification cycle, while the decoder is invoked many times autoregressively during draft generation. A single enc-dec llama_context should handle this fine since llama_encode and llama_decode can be called independently at their own cadences.

I think I'll focus and try to do the refactor of the speculative context to use a single llama_context that mirrors the main server llama_context. For that part, we don't need the new embeddings API yet - we can validate that the logic is sound by using the draft-model based speculative decoding.

Sounds good to me. It would be great to test EAGLE3 with a single enc-dec llama_context as well; once it works, DFlash and Gemma4 MTP should also work.

@ggerganov
Member Author

Yes, to clarify: the way I think about the encoder part of the draft context is that its purpose is "to transform" (usually meaning "to compress") the target embeddings into draft embeddings. Technically, we can think of it as if the draft context always has an encoder, but sometimes (like in the regular MTP case) it is an identity operator, so we skip it.
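
In pseudo-code terms (has_encoder and encode are placeholders for this framing, not real calls):

// conceptual: every draft context has an "encoder" stage; for plain MTP it is
// the identity, so it can be skipped
llama_embeddings embd_dft = has_encoder(model_dft)
    ? encode(ctx_dft, embd_tgt) // Eagle3/DFlash/Gemma4 MTP: compress target states
    : embd_tgt;                 // plain MTP: identity, feed target states directly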

Sounds good to me. It would be great to test EAGLE3 with a single enc-dec llama_context as well; once it works, DFlash and Gemma4 MTP should also work.

Yeah, I think this refactor is going to be a big change and still some things are a bit fuzzy to me. But I feel making it work would be 90% of what we need to be able to support all methods cleanly. In any case, I will open a draft PR when I start working on it and we can keep track of the progress.

@ruixiang63

Yes, to clarify: the way I think about the encoder part of the draft context is that its purpose is "to transform" (usually meaning "to compress") the target embeddings into draft embeddings. Technically, we can think of it as if the draft context always has an encoder, but sometimes (like in the regular MTP case) it is an identity operator, so we skip it.

100% correct. Sounds great.

Yeah, I think this refactor is going to be a big change and still some things are a bit fuzzy to me. But I feel making it work would be 90% of what we need to be able to support all methods cleanly. In any case, I will open a draft PR when I start working on it and we can keep track of the progress.

Makes sense. Thanks for that!

@am17an
Contributor

am17an commented May 6, 2026

Yeah, I think this refactor is going to be a big change and still some things are a bit fuzzy to me.

Yes, although without this refactor, supporting speculators/drafters with multi-seq will be extremely messy. This is a good way to keep the changes contained inside common_speculative. Performance-wise, PP takes a hit using host-side copies, but that is unrelated to the refactor and can be optimized later.

@ruixiang63

I looked a bit deeper into Gemma4 MTP. This may not be relevant to the current refactoring, but here are a few observations:

  • It still uses activations from the last layer of the target model, same as other MTP approaches. These activations are concatenated with the token embeddings and then down-projected to the drafter model’s hidden dimension. This part is similar to Eagle3.

  • The draft model cross-attends to the target model’s KV cache instead of building its own. This is somewhat similar to DFlash’s KV-cache injection mechanism. However, DFlash performs cross-attention over compressed features extracted from the target model, whereas Gemma4 MTP directly uses the KV cache from the target model’s last layers. llama.cpp does not currently support this mechanism very well.

  • The LM head uses a sparse decoding technique to identify the most likely token clusters to predict. The general idea is similar to reducing the draft model’s vocabulary size. However, Eagle3 trains a reduced embedding layer from scratch, while Gemma4 uses clustering and mapping to reduce computation overhead.

@am17an
Contributor

am17an commented May 7, 2026

The LM head uses a sparse decoding technique to identify the most likely token clusters to predict. The general idea is similar to reducing the draft model’s vocabulary size. However, Eagle3 trains a reduced embedding layer from scratch, while Gemma4 uses clustering and mapping to reduce computation overhead.

This is not strictly required; it is just a way to get faster logits, I guess because of the large Gemma vocab (~262k tokens). So the current MTP class would work if we can find a way to hook up the drafter's KV cache to the target's.

@ggerganov
Member Author

It still uses activations from the last layer of the target model, same as other MTP approaches.

Ah yes, I somehow thought it didn't. Thanks.

Started draft PR here: #22787

@ggerganov
Member Author

ggerganov commented May 7, 2026

A question arises - consider the following:

# target context with processed tokens
ABCDEFG

# draft context now generates a draft (lowercase)
ABCDEFGhijk

# uppercase - tokens "infused" with target embeddings
# lowercase - tokens generated without target embeddings

Let's say the draft hijk gets accepted fully. Note that at this point, the draft was generated entirely by the draft context, so it is not "infused" with the target model embeddings.

Do we need to re-evaluate the hijk tokens again using the draft context + the target model embeddings (which we get from the acceptance)? I think the answer is "yes", but it's something to verify. On master we don't re-evaluate it, and in #22787 I've fixed this.

@am17an
Contributor

am17an commented May 8, 2026

Nice find! I did a bit of searching and it looks like this is similar to this bug: vllm-project/vllm#14649. If so, would we need an extra verify pass that makes the draft up-to-date?

@ggerganov
Member Author

It still uses activations from the last layer of the target model, same as other MTP approaches.

Regarding Gemma4 MTP - something still does not make sense. If the KV cache of the target model is reused by the draft context, then there is no need to infuse the draft context with target embeddings for the prompt. I think the target embeddings are only needed by the first token in the draft (i.e. id_last). So it's not the same as the other MTP variants.
