From 8e063f0dbc4531762c1c7966ba0c1b7f2a190410 Mon Sep 17 00:00:00 2001 From: tangpanyu Date: Thu, 23 Apr 2026 14:41:16 +0800 Subject: [PATCH] docs: fix sign of online softmax rescaling factors --- docs/20250422-new-kernel-deep-dive.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/20250422-new-kernel-deep-dive.md b/docs/20250422-new-kernel-deep-dive.md index da0b6dca..85086b63 100644 --- a/docs/20250422-new-kernel-deep-dive.md +++ b/docs/20250422-new-kernel-deep-dive.md @@ -27,10 +27,10 @@ Our solution involves an additional mathematical transformation beyond FlashAtte 0. Maintain a running max $m$ (initialized to $-\infty$, shared between the two warpgroups) and output matrices $\vec o_L, \vec o_R$ (initialized to 0). 1. [0] Compute $`\vec p_0 = \vec q K_0^\intercal / qk\_scale`$. 2. [1] Compute $`\vec p_1 = \vec q K_1^\intercal / qk\_scale`$. -3. [0] Compute $mp_0 = \max(\vec p_0)$, $`m\_new_0 = \max(m, mp_0)`$, and $`scale_0 = \exp(m\_new_0 - m)`$. Update $`m \gets m\_new_0`$. +3. [0] Compute $mp_0 = \max(\vec p_0)$, $`m\_new_0 = \max(m, mp_0)`$, and $`scale_0 = \exp(m - m\_new_0)`$. Update $`m \gets m\_new_0`$. 4. [0] Perform softmax on $\vec p_0$: $`\vec p_0 \gets \exp(\vec p_0 - m\_new_0)`$. 5. [0] Update $\vec o_L \gets \vec o_L \cdot scale_0 + \vec p_0 V_{0L}$. -6. [1] Compute $mp_1 = \max(\vec p_1)$, $`m\_new_1 = \max(m, mp_1)`$, and $`scale_1 = \exp(m\_new_1 - m)`$. Update $`m \gets m\_new_1`$. +6. [1] Compute $mp_1 = \max(\vec p_1)$, $`m\_new_1 = \max(m, mp_1)`$, and $`scale_1 = \exp(m - m\_new_1)`$. Update $`m \gets m\_new_1`$. 7. [1] Perform softmax on $\vec p_1$: $`\vec p_1 \gets \exp(\vec p_1 - m\_new_1)`$. 8. [1] Update $\vec o_R \gets \vec o_R \cdot (scale_0 \cdot scale_1) + \vec p_1 V_{1R}$. 9. [0] Update $\vec p_0 \gets \vec p_0 \cdot scale_1$.