Commit 0627983

[Gaudi] use pad_token_id to pad input id (#3268)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent 3752143 commit 0627983

4 files changed, 40 insertions(+), 16 deletions(-)

Dockerfile_gaudi

Lines changed: 0 additions & 2 deletions
@@ -95,8 +95,6 @@ RUN cd server && \
     make gen-server && \
     pip install --no-deps -r requirements.txt && \
     bash ./dill-0.3.8-patch.sh && \
-    pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
-    BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
     pip install . --no-cache-dir
 RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix
 RUN pip install compressed-tensors==0.9.1

backends/gaudi/server/text_generation_server/models/flash_causal_lm.py

Lines changed: 36 additions & 10 deletions
@@ -975,7 +975,7 @@ def concatenate(
             valid_indices=None,
         )
 
-    def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
+    def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx, pad_token_id):
         block_num = [length // BLOCK_SIZE + 1 for length in self.cache_lengths]
         block_tables = []
         for i, bt in enumerate(self.block_tables):
@@ -998,7 +998,7 @@ def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
             bucketing_ctx,
         )
         self.input_ids = F.pad(
-            self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=0
+            self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=pad_token_id
         )
 
         if self.position_ids.dim() == 2:
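For reference, F.pad with a (0, n) spec on a 1-D tensor appends n elements on the right, so the slots added to reach the bucketed batch size now hold the pad token instead of id 0. A minimal standalone sketch of that call (the ids, pad id, and bucket size below are made up for illustration):

import torch
import torch.nn.functional as F

pad_token_id = 128001                    # hypothetical pad id resolved from the tokenizer
input_ids = torch.tensor([17, 42, 99])   # one next-token id per live request
padded_bs = 8                            # padded batch size from the bucketing context

# Right-pad the batch dimension up to the bucket size with the pad token.
padded = F.pad(input_ids, (0, padded_bs - input_ids.shape[0]), value=pad_token_id)
# padded == tensor([17, 42, 99, 128001, 128001, 128001, 128001, 128001])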
@@ -1040,7 +1040,7 @@ def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
         )
 
     def prepare_for_prefill(
-        self, max_padded_input_len, max_padded_bs, max_total_tokens
+        self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
     ):
         # Prepare values if we need to continue prefilling
         # Speculation must be ignored while we prefill even with chunking
@@ -1064,18 +1064,23 @@ def prepare_for_prefill(
             for input_id in self.input_ids:
                 padded = self.max_input_length - len(input_id) + extra_pad
                 if padded > 0:
-                    input_id = [0] * padded + input_id
+                    input_id = [pad_token_id] * padded + input_id
                 input_ids.append(input_id)
                 input_ids_padded_length.append(padded)
             input_ids = np.concatenate(input_ids, dtype=np.int64)
             self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         elif isinstance(self.input_ids, list):
             input_ids = self.input_ids[0]
             input_ids_padded_length.append(extra_pad)
-            input_ids = [0] * extra_pad + input_ids
+            input_ids = [pad_token_id] * extra_pad + input_ids
             self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         else:
-            input_ids = self.input_ids.new_zeros(max_padded_input_len * len(self))
+            input_ids = torch.full(
+                (max_padded_input_len * len(self),),
+                pad_token_id,
+                dtype=torch.int64,
+                device=self.input_ids.device,
+            )
             src_pos = 0
             for i in range(len(self)):
                 end_pos = (i + 1) * max_padded_input_len
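The replaced new_zeros call produced a flat prefill buffer of id 0; torch.full builds the same buffer pre-filled with the pad token, and each request's real tokens are then copied into the right-aligned tail of its window by the loop that follows the allocation. A small self-contained sketch of that pattern (shapes and ids are illustrative, not taken from the repository):

import torch

pad_token_id = 128001
max_padded_input_len = 6
prompts = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]   # unpadded prompt ids

# One flat buffer, every slot initialised to the pad token instead of 0.
input_ids = torch.full(
    (max_padded_input_len * len(prompts),),
    pad_token_id,
    dtype=torch.int64,
)

# Right-align each prompt inside its own max_padded_input_len window.
for i, ids in enumerate(prompts):
    end_pos = (i + 1) * max_padded_input_len
    input_ids[end_pos - len(ids) : end_pos] = ids

# input_ids.view(2, 6):
# [[128001, 128001, 128001,      5,      6,      7],
#  [128001, 128001, 128001, 128001,      8,      9]]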
@@ -1090,7 +1095,7 @@ def prepare_for_prefill(
                 self.input_ids = input_ids
 
         self.input_ids = F.pad(
-            self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=0
+            self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=pad_token_id
         )
 
         self.input_lengths_tensor = torch.tensor(self.input_lengths, dtype=torch.int32)
@@ -1312,8 +1317,9 @@ def prepare_for_prefill(
             self.prefill_next_token_indices = (
                 self.prefill_next_token_indices + input_ids_padded_length_tensor
             )
-        all_input_ids_tensor = torch.zeros(
+        all_input_ids_tensor = torch.full(
             (max_padded_bs, max(max_total_tokens, self.all_input_ids_tensor.shape[-1])),
+            pad_token_id,
             dtype=torch.int64,
             device="hpu",
         )
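Similarly, the per-request token-history buffer is now allocated with torch.full instead of torch.zeros, so any slot that never receives a real token holds the pad id. A toy version of that allocation, with existing histories copied in afterwards (dimensions are illustrative, and the HPU device is omitted so the snippet runs anywhere):

import torch

pad_token_id = 128001
max_padded_bs, max_total_tokens = 4, 8

# (batch, tokens) buffer filled with the pad token rather than zeros.
all_input_ids_tensor = torch.full(
    (max_padded_bs, max_total_tokens),
    pad_token_id,
    dtype=torch.int64,
)

# Copy the ids of the live requests into the top-left corner of the buffer.
existing = torch.tensor([[11, 12, 13], [21, 22, 23]])   # two live requests
all_input_ids_tensor[: existing.shape[0], : existing.shape[1]] = existing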
@@ -1502,6 +1508,19 @@ def __init__(
         )
         self.skip_warmup = os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true"
         self.max_seq_len_to_capture = 8192
+        if tokenizer.pad_token_id is None:
+            if config.pad_token_id is not None:
+                tokenizer.pad_token_id = config.pad_token_id
+            elif config.eos_token_id is not None:
+                tokenizer.pad_token_id = (
+                    config.eos_token_id[0]
+                    if isinstance(config.eos_token_id, list)
+                    else config.eos_token_id
+                )
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.pad_token_id = 0
         super().__init__(
             model_id=model_id,
             model=model,
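The constructor change above resolves a usable pad token once, falling back from the tokenizer's pad id to config.pad_token_id, then config.eos_token_id (first entry if it is a list), then the tokenizer's eos id, and finally 0. The same chain written as a standalone helper, purely for illustration (the function name is mine, not part of the codebase):

def resolve_pad_token_id(tokenizer, config) -> int:
    """Pick a pad token id, mirroring the fallback chain added in __init__."""
    if tokenizer.pad_token_id is not None:
        return tokenizer.pad_token_id
    if config.pad_token_id is not None:
        return config.pad_token_id
    if config.eos_token_id is not None:
        # Some configs store a list of eos ids; take the first one.
        if isinstance(config.eos_token_id, list):
            return config.eos_token_id[0]
        return config.eos_token_id
    if tokenizer.eos_token_id is not None:
        return tokenizer.eos_token_id
    return 0

# e.g. tokenizer.pad_token_id = resolve_pad_token_id(tokenizer, config)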
@@ -2274,14 +2293,21 @@ def generate_token(
                     ),
                     self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)),
                     self.max_total_tokens,
+                    self.tokenizer.pad_token_id,
                 )
             else:
                 batch.prepare_for_prefill(
-                    batch.max_input_length, len(batch), self.max_total_tokens
+                    batch.max_input_length,
+                    len(batch),
+                    self.max_total_tokens,
+                    self.tokenizer.pad_token_id,
                 )
         else:
             batch.prepare_for_decode(
-                self.dtype, self.use_contiguous_pa, self.bucketing_ctx
+                self.dtype,
+                self.use_contiguous_pa,
+                self.bucketing_ctx,
+                self.tokenizer.pad_token_id,
             )
         if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
             self.set_inputs_embeds(batch)

backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py

Lines changed: 2 additions & 2 deletions
@@ -554,10 +554,10 @@ def from_pb_processor(
         return batch
 
     def prepare_for_prefill(
-        self, max_padded_input_len, max_padded_bs, max_total_tokens
+        self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
     ):
         super().prepare_for_prefill(
-            max_padded_input_len, max_padded_bs, max_total_tokens
+            max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
         )
 
         self.has_image_inputs = False

backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py

Lines changed: 2 additions & 2 deletions
@@ -47,10 +47,10 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
     cross_attention_states: Optional[torch.Tensor] = None
 
     def prepare_for_prefill(
-        self, max_padded_input_len, max_padded_bs, max_total_tokens
+        self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
     ):
         super(FlashVlmCausalLMBatch, self).prepare_for_prefill(
-            max_padded_input_len, max_padded_bs, max_total_tokens
+            max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
         )
 
     @classmethod
