
Commit 4b460e7

fix(server): fix flash batch filtering (#220)

1 parent 1ffea36 · commit 4b460e7

1 file changed with 4 additions and 3 deletions

server/text_generation_server/models/flash_causal_lm.py

@@ -188,9 +188,10 @@ def filter(self, requests: List[generate_pb2.Request]) -> "FlashCausalLMBatch":
             position_ids.append(self.position_ids[idx])
             cu_seqlens.append(cumulative_length + request_input_length)
             max_seqlen = max(max_seqlen, request_input_length)
+            # True index for past
+            past_key_values.append(self.past_key_values[2 * idx])
+
             if not single_request:
-                # True index for past
-                past_key_values.append(self.past_key_values[2 * idx])
                 # Add one padding
                 past_key_values.append(self.past_pad)
 
@@ -209,7 +210,7 @@ def filter(self, requests: List[generate_pb2.Request]) -> "FlashCausalLMBatch":
         if single_request:
             # Preallocate tensor for bs = 1 case
             past_key_values = torch.nn.functional.pad(
-                self.past_key_values[0],
+                past_key_values[0],
                 (
                     0,
                     0,
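
For context, the fix moves the past_key_values append out of the "if not single_request:" branch, so that when filtering reduces the batch to a single request, the bs = 1 preallocation path pads the kept request's past (past_key_values[0]) instead of the pre-filter batch's first entry (self.past_key_values[0]). Below is a minimal sketch of the interleaved layout implied by the 2 * idx indexing; filter_past and stored_past are hypothetical names for illustration, not the library's API.

    # Minimal sketch (not the real FlashCausalLMBatch): each request's past
    # tensor sits at even slot 2 * idx, with a shared padding tensor in the
    # odd slots between requests.
    past_pad = "PAD"                    # stand-in for self.past_pad
    stored_past = ["past_0", past_pad,  # request 0 -> slot 0
                   "past_1", past_pad]  # request 1 -> slot 2

    def filter_past(keep_indices, single_request):
        # Hypothetical helper mirroring the fixed loop in filter()
        past_key_values = []
        for idx in keep_indices:
            # True index for past: request idx lives at slot 2 * idx
            past_key_values.append(stored_past[2 * idx])
            if not single_request:
                # Add one padding between surviving requests
                past_key_values.append(past_pad)
        return past_key_values

    print(filter_past([1], single_request=True))      # ['past_1']
    print(filter_past([0, 1], single_request=False))  # ['past_0', 'PAD', 'past_1', 'PAD']

Before the fix, the single-request path appended nothing in the loop, so the preallocation read self.past_key_values[0], which belongs to whichever request came first in the unfiltered batch rather than the request that survived filtering.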
