Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/20250422-new-kernel-deep-dive.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ Our solution involves an additional mathematical transformation beyond FlashAtte
0. Maintain a running max $m$ (initialized to $-\infty$, shared between the two warpgroups) and output matrices $\vec o_L, \vec o_R$ (initialized to 0).
1. [0] Compute $`\vec p_0 = \vec q K_0^\intercal / qk\_scale`$.
2. [1] Compute $`\vec p_1 = \vec q K_1^\intercal / qk\_scale`$.
3. [0] Compute $mp_0 = \max(\vec p_0)$, $`m\_new_0 = \max(m, mp_0)`$, and $`scale_0 = \exp(m\_new_0 - m)`$. Update $`m \gets m\_new_0`$.
3. [0] Compute $mp_0 = \max(\vec p_0)$, $`m\_new_0 = \max(m, mp_0)`$, and $`scale_0 = \exp(m - m\_new_0)`$. Update $`m \gets m\_new_0`$.
4. [0] Perform softmax on $\vec p_0$: $`\vec p_0 \gets \exp(\vec p_0 - m\_new_0)`$.
5. [0] Update $\vec o_L \gets \vec o_L \cdot scale_0 + \vec p_0 V_{0L}$.
6. [1] Compute $mp_1 = \max(\vec p_1)$, $`m\_new_1 = \max(m, mp_1)`$, and $`scale_1 = \exp(m\_new_1 - m)`$. Update $`m \gets m\_new_1`$.
6. [1] Compute $mp_1 = \max(\vec p_1)$, $`m\_new_1 = \max(m, mp_1)`$, and $`scale_1 = \exp(m - m\_new_1)`$. Update $`m \gets m\_new_1`$.
7. [1] Perform softmax on $\vec p_1$: $`\vec p_1 \gets \exp(\vec p_1 - m\_new_1)`$.
8. [1] Update $\vec o_R \gets \vec o_R \cdot (scale_0 \cdot scale_1) + \vec p_1 V_{1R}$.
9. [0] Update $\vec p_0 \gets \vec p_0 \cdot scale_1$.
Expand Down