@@ -975,7 +975,7 @@ def concatenate(
             valid_indices=None,
         )

-    def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
+    def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx, pad_token_id):
         block_num = [length // BLOCK_SIZE + 1 for length in self.cache_lengths]
         block_tables = []
         for i, bt in enumerate(self.block_tables):
@@ -998,7 +998,7 @@ def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
             bucketing_ctx,
         )
         self.input_ids = F.pad(
-            self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=0
+            self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=pad_token_id
         )

         if self.position_ids.dim() == 2:
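For reference, a minimal standalone sketch (not part of the patch) of what the `F.pad` call above does during decode: the 1-D batch of next-token ids is right-padded up to the bucketed batch size with the tokenizer's pad token instead of a hard-coded 0. The concrete ids, `padded_bs`, and `pad_token_id` values here are illustrative.

```python
import torch
import torch.nn.functional as F

# Illustrative values only: a decode batch of 3 token ids padded up to a
# bucketed batch size of 8, using the tokenizer's pad token instead of 0.
pad_token_id = 2
padded_bs = 8

input_ids = torch.tensor([17, 3024, 911], dtype=torch.int64)
input_ids = F.pad(input_ids, (0, padded_bs - input_ids.shape[0]), value=pad_token_id)

print(input_ids)  # tensor([  17, 3024,  911,    2,    2,    2,    2,    2])
```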
@@ -1040,7 +1040,7 @@ def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
         )

     def prepare_for_prefill(
-        self, max_padded_input_len, max_padded_bs, max_total_tokens
+        self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
     ):
         # Prepare values if we need to continue prefilling
         # Speculation must be ignored while we prefill even with chunking
@@ -1064,18 +1064,23 @@ def prepare_for_prefill(
             for input_id in self.input_ids:
                 padded = self.max_input_length - len(input_id) + extra_pad
                 if padded > 0:
-                    input_id = [0] * padded + input_id
+                    input_id = [pad_token_id] * padded + input_id
                 input_ids.append(input_id)
                 input_ids_padded_length.append(padded)
             input_ids = np.concatenate(input_ids, dtype=np.int64)
             self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         elif isinstance(self.input_ids, list):
             input_ids = self.input_ids[0]
             input_ids_padded_length.append(extra_pad)
-            input_ids = [0] * extra_pad + input_ids
+            input_ids = [pad_token_id] * extra_pad + input_ids
             self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         else:
-            input_ids = self.input_ids.new_zeros(max_padded_input_len * len(self))
+            input_ids = torch.full(
+                (max_padded_input_len * len(self),),
+                pad_token_id,
+                dtype=torch.int64,
+                device=self.input_ids.device,
+            )
             src_pos = 0
             for i in range(len(self)):
                 end_pos = (i + 1) * max_padded_input_len
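A self-contained sketch of the pattern in the new `else` branch: allocate one flat buffer prefilled with the pad token via `torch.full`, then copy each prompt right-aligned into its fixed-size slot, which is what the `src_pos`/`end_pos` loop does on the real batch. The prompts and sizes below are made up for illustration.

```python
import torch

pad_token_id = 2          # illustrative pad id
max_padded_input_len = 6  # bucketed per-request length

# Two variable-length prompts (illustrative).
prompts = [
    torch.tensor([5, 6, 7], dtype=torch.int64),
    torch.tensor([8, 9, 10, 11], dtype=torch.int64),
]

# Flat buffer prefilled with the pad token rather than zeros.
input_ids = torch.full(
    (max_padded_input_len * len(prompts),),
    pad_token_id,
    dtype=torch.int64,
)

# Right-align each prompt inside its fixed-size slot (left padding).
for i, prompt in enumerate(prompts):
    end_pos = (i + 1) * max_padded_input_len
    input_ids[end_pos - len(prompt):end_pos] = prompt

print(input_ids.view(len(prompts), max_padded_input_len))
# tensor([[ 2,  2,  2,  5,  6,  7],
#         [ 2,  2,  8,  9, 10, 11]])
```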
@@ -1090,7 +1095,7 @@ def prepare_for_prefill(
             self.input_ids = input_ids

         self.input_ids = F.pad(
-            self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=0
+            self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=pad_token_id
         )

         self.input_lengths_tensor = torch.tensor(self.input_lengths, dtype=torch.int32)
@@ -1312,8 +1317,9 @@ def prepare_for_prefill(
         self.prefill_next_token_indices = (
             self.prefill_next_token_indices + input_ids_padded_length_tensor
         )
-        all_input_ids_tensor = torch.zeros(
+        all_input_ids_tensor = torch.full(
             (max_padded_bs, max(max_total_tokens, self.all_input_ids_tensor.shape[-1])),
+            pad_token_id,
             dtype=torch.int64,
             device="hpu",
         )
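The switch from `torch.zeros` to `torch.full` matters because id 0 is usually a real vocabulary token, so zero-initialized slots in `all_input_ids_tensor` are indistinguishable from generated content; filling with the pad id keeps the unused tail inert. A toy comparison, with illustrative sizes and pad id:

```python
import torch

pad_token_id = 2                      # illustrative pad id
max_padded_bs, max_total_tokens = 2, 6

# Zero-filled buffer: unused slots look like token id 0, normally a real token.
zeros = torch.zeros((max_padded_bs, max_total_tokens), dtype=torch.int64)

# Pad-filled buffer: unused slots carry the pad id and stay distinguishable.
filled = torch.full((max_padded_bs, max_total_tokens), pad_token_id, dtype=torch.int64)

print(zeros[0])   # tensor([0, 0, 0, 0, 0, 0])
print(filled[0])  # tensor([2, 2, 2, 2, 2, 2])
```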
@@ -1502,6 +1508,19 @@ def __init__(
         )
         self.skip_warmup = os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true"
         self.max_seq_len_to_capture = 8192
+        if tokenizer.pad_token_id is None:
+            if config.pad_token_id is not None:
+                tokenizer.pad_token_id = config.pad_token_id
+            elif config.eos_token_id is not None:
+                tokenizer.pad_token_id = (
+                    config.eos_token_id[0]
+                    if isinstance(config.eos_token_id, list)
+                    else config.eos_token_id
+                )
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.pad_token_id = 0
         super().__init__(
             model_id=model_id,
             model=model,
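The new `__init__` block resolves a usable pad token once, so the padding code above only falls back to a hard-coded 0 when nothing else is available. A standalone sketch of the same fallback order (the helper name and the `SimpleNamespace` stand-ins are illustrative, not project API):

```python
from types import SimpleNamespace


def resolve_pad_token_id(tokenizer, config):
    """Mirror the fallback chain: tokenizer pad -> config pad -> config EOS -> tokenizer EOS -> 0."""
    if tokenizer.pad_token_id is not None:
        return tokenizer.pad_token_id
    if config.pad_token_id is not None:
        return config.pad_token_id
    if config.eos_token_id is not None:
        # Some configs store a list of EOS ids; use the first entry.
        if isinstance(config.eos_token_id, list):
            return config.eos_token_id[0]
        return config.eos_token_id
    if tokenizer.eos_token_id is not None:
        return tokenizer.eos_token_id
    return 0  # last resort, as in the patch


# Example: tokenizer without a pad token, config that only defines a list of EOS ids.
tok = SimpleNamespace(pad_token_id=None, eos_token_id=None)
cfg = SimpleNamespace(pad_token_id=None, eos_token_id=[7, 8])
print(resolve_pad_token_id(tok, cfg))  # 7
```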
@@ -2274,14 +2293,21 @@ def generate_token(
                    ),
                    self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)),
                    self.max_total_tokens,
+                   self.tokenizer.pad_token_id,
                )
            else:
                batch.prepare_for_prefill(
-                   batch.max_input_length, len(batch), self.max_total_tokens
+                   batch.max_input_length,
+                   len(batch),
+                   self.max_total_tokens,
+                   self.tokenizer.pad_token_id,
                )
        else:
            batch.prepare_for_decode(
-               self.dtype, self.use_contiguous_pa, self.bucketing_ctx
+               self.dtype,
+               self.use_contiguous_pa,
+               self.bucketing_ctx,
+               self.tokenizer.pad_token_id,
            )
        if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
            self.set_inputs_embeds(batch)