Skip to content

Commit 9c5ec4a

Browse files
committed
change HPU warmup logic: sequence lengths should grow exponentially
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
1 parent 329f612 commit 9c5ec4a

File tree

1 file changed

+22
-11
lines changed
  • backends/gaudi/server/text_generation_server/models

1 file changed

+22
-11
lines changed

backends/gaudi/server/text_generation_server/models/causal_lm.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
5757
LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
5858
BATCH_SIZE_EXPONENT_BASE = int(os.environ.get("BATCH_SIZE_EXPONENT_BASE", 2))
59+
SEQ_LEN_EXPONENT_BASE = int(os.environ.get("SEQ_LEN_EXPONENT_BASE", 2))
5960
MAX_BATCH_SIZE = (
6061
int(os.environ.get("MAX_BATCH_SIZE"))
6162
if os.environ.get("MAX_BATCH_SIZE") is not None
@@ -71,8 +72,21 @@ def torch_compile_for_eager(func):
7172
)
7273

7374

74-
def round_up_seq(number, k):
75-
return (number + k - 1) // k * k
75+
def round_up_seq(number, k, base):
76+
exponent = math.ceil(math.log(number / k, base))
77+
return k * (base**exponent)
78+
79+
80+
def iterate_powers_of_base(max_value, start, base):
81+
current = start
82+
result = []
83+
assert (
84+
max_value >= start
85+
), f"max_value {max_value} must be greater than start {start}"
86+
while current < max_value:
87+
result.append(current)
88+
current *= base
89+
return result
7690

7791

7892
def round_up_batch(number):
@@ -575,7 +589,9 @@ def from_pb(
575589
assert (
576590
PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length
577591
), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
578-
rounded_seq_len = round_up_seq(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
592+
rounded_seq_len = round_up_seq(
593+
input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE
594+
)
579595
if rounded_seq_len <= max_input_length:
580596
bucket_size = rounded_seq_len - 1
581597
else:
@@ -1345,14 +1361,9 @@ def warmup(
13451361
max_exp + 1,
13461362
)
13471363
]
1348-
prefill_seqlen_list = [
1349-
seq
1350-
for seq in range(
1351-
PAD_SEQUENCE_TO_MULTIPLE_OF,
1352-
max_input_tokens,
1353-
PAD_SEQUENCE_TO_MULTIPLE_OF,
1354-
)
1355-
]
1364+
prefill_seqlen_list = iterate_powers_of_base(
1365+
max_input_tokens, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE
1366+
)
13561367
prefill_seqlen_list.append(max_input_tokens)
13571368
prefill_batch_size_list.sort(reverse=True)
13581369
prefill_seqlen_list.sort(reverse=True)

0 commit comments

Comments
 (0)