Conversation
|
Let me add my comment here since that PR is getting a bit crowded.
|
It looks like for Gemma4 MTP, the MTP head actually attends to the target's KV cache - something to keep in mind.
I don't think it is feasible to have the main and MTP graphs combined. Having the MTP graph (and any other speculative decoding graph) in a separate context has many advantages:
There are many different variants of speculative decoding and more will appear in the future. We cannot stuff all this logic inside the llama_context. For the Gemma4 MTP, I would try to create a memory-less context.
|
Okay, I agree that makes sense - stuffing everything into llama_context is not good. So what we need is some sort of checkpoint (which is persistent, to make it work for prompt_cache use-cases) after each ubatch that the speculator can hook onto and run a prompt batch itself, if I understand this correctly.
|
The current speculative API basically assumes that the only input for the speculator is the prompt tokens:

```cpp
// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_draft(
        common_speculative * spec,
        const common_params_speculative & params,
        const llama_tokens & prompt,  // <--- here
        llama_token id_last);
```

For draft-model based speculative decoding (and also for Gemma4 MTP) this is true. But for standard MTP, Eagle3, DFlash - this is not true. We need to also pass the target-model embeddings (pre-output norm for MTP and intermediate layer-input embeddings for Eagle3). I think we probably need to wrap this in a helper struct, something like this:

```cpp
struct common_tokens {
    llama_tokens prompt;

    // optional embeddings collection from the target model
    // TODO: some generic way to encompass various embeddings. not sure exactly how
    // note: can become large in some cases
    llama_embeddings embd;
};

llama_tokens common_speculative_draft(
        common_speculative * spec,
        const common_params_speculative & params,
        const common_tokens & prompt,
        llama_token id_last);
```

This way, similar to the draft-model based implementation that we currently have, we can do the prefix reuse and prompt reprocessing of the MTP context. The server logic would need to be updated to manage
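For illustration, the server-side call could then look roughly like this. This is just a sketch - it assumes the `common_tokens` / `common_speculative_draft` shape above, the `llama_embeddings` representation is still TBD, and the helper name is made up:

```cpp
// Illustrative sketch only - assumes the proposed common_tokens struct and the
// common_speculative_draft() signature from above; all names are tentative.
static llama_tokens slot_draft(
        common_speculative              * spec,
        const common_params_speculative & params,
        const llama_tokens              & prompt_tokens, // tokens already verified by the target
        const llama_embeddings          & target_embd,   // embeddings captured from the target model
        llama_token                       id_last) {     // last sampled/accepted token
    common_tokens inp;
    inp.prompt = prompt_tokens;
    inp.embd   = target_embd; // stays empty for draft-model based speculation (and Gemma4 MTP)

    // the speculator handles prefix reuse / reprocessing of its own context internally
    return common_speculative_draft(spec, params, inp, id_last);
}
```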
|
@ggerganov this makes sense to me - I can move all the stuff out of llama-context and into the common speculative class. This also means that the embeddings are shared and not copied if they are on the same backend, right?
|
Actually, it doesn't seem like this would address this particular problem of sharing the tensor between two llama contexts, but at least we can do a device-to-device copy.
It would be nice if the speculator could mark the layer hidden states as required using some API. For MTP it would be the last hidden layer pre-norm, and for eagle3 it would be a bunch of other layers as well. Then it can get exactly those embeddings back.
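Something along these lines is what I have in mind - purely a sketch, none of these functions exist yet and all names are made up; `common_speculative` and `llama_embeddings` refer to the types from the proposal above:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical API sketch. The speculator declares up-front which target
// hidden states it needs:
//   - MTP:    last layer, pre output-norm
//   - eagle3: several intermediate layers (model-dependent indices)
struct common_speculative_embd_req {
    std::vector<int32_t> il;       // layer indices to capture (-1 = last layer)
    bool                 pre_norm; // capture the hidden state before the output norm
};

// register the request with the target context before decoding
void common_speculative_set_embd_req(common_speculative * spec, const common_speculative_embd_req & req);

// after llama_decode() on the target, fetch exactly the requested embeddings
// (ideally without a host round-trip when both contexts live on the same backend)
llama_embeddings common_speculative_get_embd(common_speculative * spec);
```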
|
One issue that I realized now is that keeping all of the embeddings for the entire prompt in host memory would not be a good idea - slow, lots of memory. We have to pass them in batches (not ubatches) to the speculative context immediately. To do that, I am thinking we first need to refactor the speculative contexts to have a single shared multi-sequence llama context (currently we have a llama context per slot). Doing so would allow us to have a single speculative llama context that "mirrors" the main server llama context. We would then decode the same batches, make the same checkpoints, etc. We would also be able to draft for multiple sequences in parallel. I think we have to do this initial refactor first, using the existing draft-model based speculative decoding.
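To make the "mirroring" idea a bit more concrete, something like this could work. Rough sketch only: `llama_batch_init`, `common_batch_add`, `llama_decode` and `llama_batch_free` are the existing helpers, while `slot_view` and `decode_mirrored` are made up for illustration:

```cpp
#include "llama.h"
#include "common.h"

#include <vector>

// one entry per server slot (illustrative)
struct slot_view {
    llama_seq_id             seq_id;
    llama_pos                pos;    // position of the first new token
    std::vector<llama_token> tokens; // new tokens for this slot in the current iteration
};

// decode the same batch into the main context and into the single,
// multi-sequence speculative context, so the latter stays in sync ("mirrors" it)
static void decode_mirrored(llama_context * ctx_main, llama_context * ctx_spec,
                            const std::vector<slot_view> & slots, int32_t n_batch) {
    llama_batch batch = llama_batch_init(n_batch, 0, (int32_t) slots.size());

    for (const auto & slot : slots) {
        for (size_t i = 0; i < slot.tokens.size(); ++i) {
            common_batch_add(batch, slot.tokens[i], slot.pos + (llama_pos) i, { slot.seq_id }, false);
        }
    }

    llama_decode(ctx_main, batch);
    // same tokens, same seq_ids, same checkpoints in the speculative context
    // (this is also where the target embeddings for this batch would be handed over)
    llama_decode(ctx_spec, batch);

    llama_batch_free(batch);
}
```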
|
Thanks for the proposal and discussion here. I have a few observations on how this generalizes to EAGLE3/DFlash vs. MTP:
So my feeling about the path forward is:
|
Yes, I am not sure about that yet. Worst case is we still have to go through host memory, but at least we have to transfer the target embeddings just for the current batch. Though, likely we can do better as you mention.
The new Gemma4 MTP is similar in that regard - it has an "encoder" for compressing the target states. In general, I think we don't need separate contexts for the encoder and the decoder.
Why would they have different lifecycles?
|
I think I'll focus on this and try to do the refactor of the speculative context to use a single shared multi-sequence llama_context first.
Got it, that resolves the three-context concern cleanly.
I meant in eagle3 the encoder fires once per target verification cycle, while the decoder is invoked many times autoregressively during draft generation. A single enc-dec llama_context should handle this fine since llama_encode and llama_decode can be called independently at their own cadences.
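Roughly, the cadence for such an enc-dec draft context would look like this. Sketch only - `llama_encode`/`llama_decode` are the existing entry points, while the helper functions are hypothetical placeholders:

```cpp
#include "llama.h"

// Hypothetical helpers (declarations only) - these would live in common_speculative:
llama_batch batch_from_target_embd(llama_context * ctx_draft);                 // wraps the captured target states
llama_batch batch_from_token(llama_token id, llama_pos pos, llama_seq_id seq); // single-token decode batch
llama_token sample_draft_token(llama_context * ctx_draft);                     // sample next token from the drafter

static void draft_one_cycle(llama_context * ctx_draft, llama_token id_last, llama_pos pos, int n_draft) {
    // encoder: fires ONCE per target verification cycle - compresses the captured
    // target hidden states into the drafter's representation
    llama_encode(ctx_draft, batch_from_target_embd(ctx_draft));

    // decoder: invoked autoregressively, once per drafted token
    llama_token id = id_last;
    for (int i = 0; i < n_draft; ++i) {
        llama_decode(ctx_draft, batch_from_token(id, pos + i, 0));
        id = sample_draft_token(ctx_draft);
    }
}
```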
Sounds good to me. It would be great to test eagle3 for the single enc-dec llama_context case as well; once it works, DFlash and Gemma4 MTP should also work.
|
Yes, to clarify: the way I think about the encoder part of the draft context is that its purpose is "to transform" (usually meaning "to compress") the target embeddings into draft embeddings. Technically, we can think of it as if the draft context always has an encoder, but sometimes (like in the regular MTP case) it is an identity operator, so we skip it.
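As a tiny illustration of that view (made-up names, just to pin down the idea):

```cpp
#include <vector>

// The draft context conceptually always has an "encoder" mapping target
// embeddings to draft embeddings; for regular MTP it is the identity and can
// be skipped, for eagle3/DFlash/Gemma4 MTP it compresses the target states.
// Everything here is illustrative.
using embd_t = std::vector<float>;

struct draft_encoder {
    bool is_identity = true; // regular MTP

    embd_t operator()(const embd_t & target_embd) const {
        if (is_identity) {
            return target_embd; // pass the target embeddings straight through
        }
        // otherwise run the drafter's encoder sub-graph; a dummy "compression"
        // stands in here just to keep the sketch self-contained
        return embd_t(target_embd.begin(), target_embd.begin() + target_embd.size() / 2);
    }
};
```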
Yeah, I think this refactor is going to be a big change and some things are still a bit fuzzy to me. But I feel making it work would be 90% of what we need to be able to support all methods cleanly. In any case, I will open a draft PR when I start working on it and we can keep track of the progress.
100% correct. Sounds great.
Makes sense. Thanks for that!
Yes, although without this refactor, supporting speculators/drafters with multi-seq will be extremely messy. This is a good way to keep the changes contained inside common_speculative. Performance-wise, PP takes a hit from the host-side copies, but that is unrelated to the refactor and can be optimized later.
|
I looked a bit deeper into Gemma4 MTP. This may not be relevant to the current refactoring, but here are a few observations:
|
This is not strictly required, it is just a way to get faster logits - I guess because of the large gemma vocab (~262k tokens). So the current MTP class would work if we can find a way to hook up the drafter's KV cache to the target's KV cache.
Ah yes, I somehow thought it didn't. Thanks. Started draft PR here: #22787
|
A question arises - consider the following:

```
# target context with processed tokens
ABCDEFG

# draft context now generates a draft (lowercase)
ABCDEFGhijk

# uppercase - tokens "infused" with target embeddings
# lowercase - tokens generated without target embeddings
```

Let's say the draft is partially accepted by the target. Do we need to re-evaluate the accepted draft tokens in the draft context, this time "infused" with the target embeddings?
|
Nice find! I did a bit of searching and it looks like this is similar to this bug: vllm-project/vllm#14649? If so, we would need an extra verify pass that makes the draft up-to-date?
Regarding Gemma 4 MTP - something still does not make sense. If the KV cache of the target model is reused by the draft context, then there is no need to infuse the draft context with target embeddings for the prompt. I think the target embeddings are only needed by the first token in the draft (i.e.
Overview
Preparing some base functionality needed for extracting embeddings from different stages of inference. This is needed to support speculative decoding methods such as Eagle3, MTP, etc.
Additional information
TBD - still figuring out what's needed
Requirements