56
56
CHUNK_SIZES = [1 , 2 , 4 , 8 , 16 , 32 , 64 , 128 , 256 , 512 , 1024 , 2048 ]
57
57
LAZY_MODE = int (os .environ .get ("PT_HPU_LAZY_MODE" , 1 ))
58
58
BATCH_SIZE_EXPONENT_BASE = int (os .environ .get ("BATCH_SIZE_EXPONENT_BASE" , 2 ))
59
+ SEQ_LEN_EXPONENT_BASE = int (os .environ .get ("SEQ_LEN_EXPONENT_BASE" , 2 ))
59
60
MAX_BATCH_SIZE = (
60
61
int (os .environ .get ("MAX_BATCH_SIZE" ))
61
62
if os .environ .get ("MAX_BATCH_SIZE" ) is not None
@@ -71,8 +72,21 @@ def torch_compile_for_eager(func):
71
72
)
72
73
73
74
74
- def round_up_seq (number , k ):
75
- return (number + k - 1 ) // k * k
75
def round_up_seq(number, k, base):
    """Round ``number`` up to the nearest bucket of the form ``k * base**e``.

    The buckets are ``k, k * base, k * base**2, ...`` (integer ``e >= 0``).
    The smallest bucket that is greater than or equal to ``number`` is
    returned, so the result is never smaller than ``k``.

    Args:
        number: Value to round up.
        k: Smallest bucket size; must be positive.
        base: Growth factor between consecutive buckets; must be > 1.

    Returns:
        The selected bucket. For integer ``k`` and ``base`` it is an ``int``.

    Raises:
        ValueError: If ``k`` is not positive or ``base`` is not greater than 1
            (the bucket sequence would never reach ``number``).
    """
    if k <= 0:
        raise ValueError(f"k must be positive, got {k}")
    if base <= 1:
        raise ValueError(f"base must be greater than 1, got {base}")

    # Pure multiplication instead of ceil(log(number / k, base)): floating
    # point logarithms can land just above an exact power (for example
    # math.log(125, 5) == 3.0000000000000004) and push the result one bucket
    # too high, and a negative exponent would yield a float smaller than k.
    bucket = k
    while bucket < number:
        bucket *= base
    return bucket
78
+
79
+
80
def iterate_powers_of_base(max_value, start, base):
    """List ``start, start * base, start * base**2, ...`` below ``max_value``.

    Args:
        max_value: Exclusive upper bound for the generated values.
        start: First value of the sequence.
        base: Factor applied between consecutive values.

    Returns:
        A list of the generated values in increasing order; empty when
        ``start == max_value``.

    Raises:
        AssertionError: If ``max_value`` is smaller than ``start``.
        ValueError: If values have to be generated but ``start`` is not
            positive or ``base`` is not greater than 1, because the sequence
            would then never reach ``max_value``.
    """
    # Kept as an assert so existing callers see the same exception type.
    assert (
        max_value >= start
    ), f"max_value {max_value} must be greater than or equal to start {start}"

    # Without this guard the loop below never terminates: multiplying by a
    # base <= 1, or starting from a non-positive value, cannot grow toward
    # max_value.
    if start < max_value and (start <= 0 or base <= 1):
        raise ValueError(
            f"start {start} must be positive and base {base} must be greater than 1"
        )

    result = []
    current = start
    while current < max_value:
        result.append(current)
        current *= base
    return result
76
90
77
91
78
92
def round_up_batch (number ):
@@ -575,7 +589,9 @@ def from_pb(
575
589
assert (
576
590
PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length
577
591
), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
578
- rounded_seq_len = round_up_seq (input_len + 1 , PAD_SEQUENCE_TO_MULTIPLE_OF )
592
+ rounded_seq_len = round_up_seq (
593
+ input_len + 1 , PAD_SEQUENCE_TO_MULTIPLE_OF , SEQ_LEN_EXPONENT_BASE
594
+ )
579
595
if rounded_seq_len <= max_input_length :
580
596
bucket_size = rounded_seq_len - 1
581
597
else :
@@ -1345,14 +1361,9 @@ def warmup(
1345
1361
max_exp + 1 ,
1346
1362
)
1347
1363
]
1348
- prefill_seqlen_list = [
1349
- seq
1350
- for seq in range (
1351
- PAD_SEQUENCE_TO_MULTIPLE_OF ,
1352
- max_input_tokens ,
1353
- PAD_SEQUENCE_TO_MULTIPLE_OF ,
1354
- )
1355
- ]
1364
+ prefill_seqlen_list = iterate_powers_of_base (
1365
+ max_input_tokens , PAD_SEQUENCE_TO_MULTIPLE_OF , SEQ_LEN_EXPONENT_BASE
1366
+ )
1356
1367
prefill_seqlen_list .append (max_input_tokens )
1357
1368
prefill_batch_size_list .sort (reverse = True )
1358
1369
prefill_seqlen_list .sort (reverse = True )
0 commit comments