@@ -151,7 +151,7 @@ def forward(ctx: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
151
151
# The forward computation will be parallelized along the batch dimension and the queries in blocks of size `BLOCK_Q`
152
152
grid = lambda meta : (triton .cdiv (q_seq_len , meta ["BLOCK_Q" ]), batch_size * k_heads * n_groups , 1 )
153
153
_attn_fwd [grid ](
154
- q , k , v , sm_scale , lse , o ,
154
+ q , k , v , sm_scale * 1.4426950408889634 , lse , o ,
155
155
n_groups = n_groups ,
156
156
q_seq_len = q_seq_len ,
157
157
kv_seq_len = kv_seq_len ,
@@ -201,10 +201,8 @@ def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, to
201
201
dk = torch .empty_like (k )
202
202
dv = torch .empty_like (v )
203
203
204
- # $\log_2 e$
205
- RCP_LN2 = 1.4426950408889634
206
204
# Precompute $\sigma (\log_2 e) K_j$
207
- k_scaled = k * (sm_scale * RCP_LN2 )
205
+ k_scaled = k * (sm_scale * 1.4426950408889634 )
208
206
# $D_i = P^T_{i:}dP_{i:} = do^T_io_i$
209
207
pdp = torch .empty_like (lse )
210
208
# We use fixed `BLOCK_Q` for backward pass on $D$
@@ -288,7 +286,7 @@ def _get_autotune_configs(inner_loop: str) -> list:
288
286
@triton .autotune (_get_autotune_configs (inner_loop = 'key' ),
289
287
key = ["q_seq_len" , "kv_seq_len" , "d_head" , "n_groups" , "is_causal" ])
290
288
@triton .jit
291
- def _attn_fwd (t_q , t_k , t_v , sm_scale , t_lse , t_o ,
289
+ def _attn_fwd (t_q , t_k , t_v , sm_scale_log2e , t_lse , t_o ,
292
290
n_groups : tl .constexpr ,
293
291
q_seq_len : tl .constexpr ,
294
292
kv_seq_len : tl .constexpr ,
@@ -359,11 +357,6 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
359
357
# Mask for $Q$ for the last block
360
358
i_mask = offs_i < q_seq_len
361
359
362
- # Precalculate $\frac{\sigma}{\log_2 e}$.
363
- #
364
- # We will be use this when calculating $S_{ij}$ so `S` will store $S_{ij} \log_2 e$ instead.
365
- sm_scale_log2e = sm_scale * 1.44269504
366
-
367
360
# Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\inf$ and $l_i$ to $1$. So in the first update,
368
361
# the effect of initial $l_i$ is $e^{m_i - m_{i}^{\text{new}}} l_i = 0$.
369
362
#
@@ -762,9 +755,6 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
762
755
BLOCK_Q : tl .constexpr ,
763
756
BLOCK_K : tl .constexpr ,
764
757
):
765
- # $\log_e 2$
766
- LN2 : tl .constexpr = 0.6931471824645996 # type: ignore
767
-
768
758
i = tl .program_id (0 ) * BLOCK_Q
769
759
z = tl .program_id (1 ) // n_groups
770
760
g = tl .program_id (1 ) % n_groups # TODO
@@ -859,7 +849,7 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
859
849
)
860
850
861
851
# `b_dq` stores $(\log_2 e)dQ$ so multiply by $\log_e 2$ to get $dQ$
862
- b_dq *= LN2
852
+ b_dq *= 0.6931471824645996
863
853
864
854
# Save $dQ$
865
855
tl .store (p_dq , b_dq .to (t_dq .type .element_ty ), boundary_check = (0 ,))
0 commit comments