Commit 3dd36b8

all comments
1 parent 73b9892 commit 3dd36b8

File tree: 1 file changed (+4 −14)

labml_nn/transformers/flash/__init__.py

Lines changed: 4 additions & 14 deletions
@@ -151,7 +151,7 @@ def forward(ctx: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
         # The forward computation will be parallelized along the batch dimension and the queries in blocks of size `BLOCK_Q`
         grid = lambda meta: (triton.cdiv(q_seq_len, meta["BLOCK_Q"]), batch_size * k_heads * n_groups, 1)
         _attn_fwd[grid](
-            q, k, v, sm_scale, lse, o,
+            q, k, v, sm_scale * 1.4426950408889634, lse, o,
             n_groups=n_groups,
             q_seq_len=q_seq_len,
             kv_seq_len=kv_seq_len,
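
Note on this hunk: the $\log_2 e$ factor is folded into `sm_scale` on the host, so the kernel receives the product ready to use with `exp2` instead of recomputing it in every program. The identity behind it is $e^x = 2^{x \log_2 e}$: scaling the scores by $\sigma \log_2 e$ and exponentiating with base 2 yields the same softmax. A minimal equivalence check in plain PyTorch (a sketch; names here are illustrative, not from the repository):

    import torch

    LOG2_E = 1.4426950408889634  # log2(e)

    def softmax_base_e(s, sm_scale):
        # Reference softmax over scaled scores.
        return torch.softmax(s * sm_scale, dim=-1)

    def softmax_base_2(s, sm_scale):
        # Fold log2(e) into the scale once, then use exp2, as the kernel does.
        x = s * (sm_scale * LOG2_E)
        x = x - x.max(dim=-1, keepdim=True).values  # running-max subtraction
        p = torch.exp2(x)
        return p / p.sum(dim=-1, keepdim=True)

    s = torch.randn(4, 8)
    assert torch.allclose(softmax_base_e(s, 0.125), softmax_base_2(s, 0.125), atol=1e-6)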
@@ -201,10 +201,8 @@ def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, to
         dk = torch.empty_like(k)
         dv = torch.empty_like(v)

-        # $\log_2 e$
-        RCP_LN2 = 1.4426950408889634
         # Precompute $\sigma (\log_2 e) K_j$
-        k_scaled = k * (sm_scale * RCP_LN2)
+        k_scaled = k * (sm_scale * 1.4426950408889634)
         # $D_i = P^T_{i:} dP_{i:} = do^T_i o_i$
         pdp = torch.empty_like(lse)
         # We use fixed `BLOCK_Q` for backward pass on $D$
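
The backward pass applies the same folding when it precomputes $\sigma (\log_2 e) K_j$: dropping the local `RCP_LN2` and inlining the literal changes nothing numerically, and scaling $K$ once up front is equivalent to scaling every score block inside the kernel loop. A small sketch of that equivalence (hypothetical shapes, plain PyTorch):

    import torch

    LOG2_E = 1.4426950408889634
    q = torch.randn(16, 64)
    k = torch.randn(16, 64)
    sm_scale = 64 ** -0.5

    # Scaling K once up front...
    k_scaled = k * (sm_scale * LOG2_E)
    s_from_k_scaled = q @ k_scaled.T

    # ...matches scaling the scores after the matmul.
    s_from_scores = (q @ k.T) * (sm_scale * LOG2_E)
    assert torch.allclose(s_from_k_scaled, s_from_scores, atol=1e-5)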
@@ -288,7 +286,7 @@ def _get_autotune_configs(inner_loop: str) -> list:
 @triton.autotune(_get_autotune_configs(inner_loop='key'),
                  key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
 @triton.jit
-def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
+def _attn_fwd(t_q, t_k, t_v, sm_scale_log2e, t_lse, t_o,
               n_groups: tl.constexpr,
               q_seq_len: tl.constexpr,
               kv_seq_len: tl.constexpr,
@@ -359,11 +357,6 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
     # Mask for $Q$ for the last block
     i_mask = offs_i < q_seq_len

-    # Precalculate $\sigma (\log_2 e)$.
-    #
-    # We will use this when calculating $S_{ij}$, so `S` will store $S_{ij} \log_2 e$ instead.
-    sm_scale_log2e = sm_scale * 1.44269504
-
     # Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\inf$ and $l_i$ to $1$. So in the first update,
     # the effect of initial $l_i$ is $e^{m_i - m_i^{\text{new}}} l_i = 0$.
     #
@@ -762,9 +755,6 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
                  BLOCK_Q: tl.constexpr,
                  BLOCK_K: tl.constexpr,
                  ):
-    # $\log_e 2$
-    LN2: tl.constexpr = 0.6931471824645996  # type: ignore
-
     i = tl.program_id(0) * BLOCK_Q
     z = tl.program_id(1) // n_groups
     g = tl.program_id(1) % n_groups  # TODO
@@ -859,7 +849,7 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
     )

     # `b_dq` stores $(\log_2 e) dQ$ so multiply by $\log_e 2$ to get $dQ$
-    b_dq *= LN2
+    b_dq *= 0.6931471824645996

     # Save $dQ$
     tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))
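
Because the scores carried an extra factor of $\log_2 e$, the accumulated `b_dq` is $(\log_2 e) dQ$, and one multiply by $\log_e 2 = 1 / \log_2 e$ at the end recovers $dQ$. A toy check of that final rescale (a sketch with hypothetical shapes and names, plain PyTorch):

    import torch

    LOG2_E = 1.4426950408889634
    LN2 = 0.6931471824645996  # ln(2) = 1 / log2(e)

    dS = torch.randn(8, 16)
    k = torch.randn(16, 32)
    sm_scale = 0.125

    # True gradient contribution: dQ = dS @ (sm_scale * K).
    dq_true = dS @ (k * sm_scale)
    # What the kernel accumulates, since K was pre-scaled by sm_scale * log2(e).
    dq_log2 = dS @ (k * (sm_scale * LOG2_E))
    # Multiplying by ln(2) undoes the log2(e) factor.
    assert torch.allclose(dq_log2 * LN2, dq_true, atol=1e-5)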
