Commit 08b7e4a

fix(server): fix flash neox rotary embeddings (#150)
1 parent 610bb1f commit 08b7e4a

1 file changed: +4 -4 lines changed
server/text_generation_server/models/flash_neox_modeling.py

Lines changed: 4 additions & 4 deletions
@@ -319,12 +319,12 @@ def forward(
 layer_past[...] = qkv_rot[:, 1:]
 
 # output
-attn_output = torch.empty_like(qkv[:, 0])
+attn_output = torch.empty_like(qkv_rot[:, 0])
 # flash attention
 flash_attn_cuda.fwd(
-    qkv[:, 0],
-    qkv[:, 1],
-    qkv[:, 2],
+    qkv_rot[:, 0],
+    qkv_rot[:, 1],
+    qkv_rot[:, 2],
     attn_output,
     cu_seqlens,
     cu_seqlens,
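
Note on the change: the hunk suggests that the rotary embedding result lives in a separate tensor (qkv_rot) while the attention call was still reading the un-rotated qkv, so positional information never reached the attention scores. The following is a minimal, standard-library-only sketch of that failure mode, not the repository's code. The helper rotate_pair, the single rotation frequency theta, and the sample vectors and positions are all assumptions made for illustration.

```python
import math


def rotate_pair(vec, position, theta=1.0):
    """Rotate a 2-D vector by position * theta.

    Hypothetical stand-in for a one-frequency rotary embedding.
    Returns a new tuple; the input is left unchanged (out of place).
    """
    x, y = vec
    angle = position * theta
    c, s = math.cos(angle), math.sin(angle)
    return (x * c - y * s, x * s + y * c)


def dot(a, b):
    """Plain dot product, standing in for one attention score."""
    return a[0] * b[0] + a[1] * b[1]


# Illustrative query/key vectors and token positions (assumed values).
q, k = (1.0, 0.0), (1.0, 0.0)
q_pos, k_pos = 3, 1

# Out-of-place rotation: q and k themselves are not modified.
q_rot = rotate_pair(q, q_pos)
k_rot = rotate_pair(k, k_pos)

# Score computed from the un-rotated inputs: position has no effect.
score_unrotated = dot(q, k)

# Score computed from the rotated inputs: depends on q_pos - k_pos.
score_rotated = dot(q_rot, k_rot)
```

In this sketch, score_rotated equals cos((q_pos - k_pos) * theta), while score_unrotated ignores the positions entirely. That is the kind of silent discrepancy that passing qkv instead of qkv_rot would produce.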
