Describe the bug
When using `batch_size="auto"` with the `generate_until` function, the automatic batch size detection mechanism is overly conservative and often results in a batch size of 1. This significantly hinders performance and negates the benefit of the auto-sizing feature.

The root cause is that when `generate_until` calls `_detect_batch_size`, it doesn't pass any information about the actual lengths of the input requests. As a result, `_detect_batch_size` falls back to using the model's maximum possible sequence length (`self.max_length`) to create its test tensor for estimating the batch size.

This means that even if all actual input prompts are short, the batch size is determined based on the worst-case scenario of a full-context input, which almost always leads to `batch_size=1` on modern GPUs with large models. This behavior differs from the `loglikelihood` functions, which do pass request-specific information to `_detect_batch_size`.
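For context, the auto-detection works roughly by probing progressively smaller batches of a worst-case-length dummy input until a forward pass fits in memory, so the sequence length of the probe tensor directly drives the result. Below is a minimal, self-contained sketch of that idea, assuming a HuggingFace-style causal LM and accelerate's `find_executable_batch_size`; it illustrates the mechanism and is not the harness's exact code:

```python
import torch
from accelerate.utils import find_executable_batch_size


def probe_batch_size(model, seq_length: int, device: str = "cuda", start: int = 512) -> int:
    """Try a batch of `seq_length`-token dummy inputs, halving on CUDA OOM
    until a forward pass fits. The longer `seq_length` is, the smaller the
    batch size that survives."""

    @find_executable_batch_size(starting_batch_size=start)
    def forward_batch(batch_size: int) -> int:
        # Dummy token ids shaped (batch_size, seq_length); only the memory
        # consumption matters here, the outputs are discarded.
        dummy = torch.ones((batch_size, seq_length), dtype=torch.long, device=device)
        with torch.no_grad():
            model(dummy)
        return batch_size

    return forward_batch()
```

Probing with `seq_length = self.max_length` (often several thousand tokens) will settle at 1 on most single-GPU setups with a multi-billion-parameter model, whereas probing with the length of the actual prompts plus `max_gen_toks` would settle substantially higher.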
To Reproduce
- Use a large language model.
- Run an evaluation that uses the `generate_until` method (a programmatic example is sketched below this list).
- Set the batch size argument to `auto`, for example: `--batch_size auto`.
- Observe the log output, which will show:

```
Passed argument batch_size = auto. Detecting largest batch size
Determined Largest batch size: 1
```
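For a concrete reproduction, something along these lines should hit the path described above (the model and task are placeholders, and the `simple_evaluate` arguments are quoted from memory, so they may differ slightly across harness versions):

```python
import lm_eval

# Any generation-style task (i.e. one that goes through generate_until) with a
# model large enough that a full-context batch > 1 does not fit in GPU memory.
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=mistralai/Mistral-7B-v0.1,dtype=bfloat16",
    tasks=["gsm8k"],
    batch_size="auto",
)
```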
Expected behavior
The `auto` batch size mechanism should detect a batch size larger than 1 by considering the actual input lengths and the `max_gen_toks` parameter, leading to faster evaluations. In other words, the sequence length used for the estimate should reflect the requests actually being run rather than a full-context worst case.
Suggested Solution
A more accurate estimation can be achieved by calculating the expected maximum sequence length within `generate_until` and passing it to `_detect_batch_size`.
- Modify `_detect_batch_size` to accept an optional `max_length` argument. If provided, this length is used for the test tensor instead of `self.max_length`.

  ```python
  def _detect_batch_size(self, requests: Sequence | None = None, pos: int = 0, max_length: int | None = None):
      if max_length is not None:
          # Use the provided max_length
          ...
      elif requests:
          # Current logic for loglikelihood
          ...
      else:
          # Fallback to self.max_length
          max_length = self.max_length
          ...
  ```
- Update `generate_until` to calculate this `max_length` before calling `_detect_batch_size` (a sketch of this is shown after the list). The logic would be:
  - Find the maximum token length of all input contexts in the `requests` list.
  - Get `max_gen_toks` from the generation arguments.
  - Calculate `estimated_max_length = max_input_length + max_gen_toks`.
  - Call `_detect_batch_size(max_length=estimated_max_length)`.
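A minimal sketch of that second step, assuming the harness's HF model class (attribute names such as `tok_encode`, `max_gen_toks`, and `max_length` follow its existing conventions, but the exact integration point and the shape of `req.args` are assumptions):

```python
def generate_until(self, requests, disable_tqdm: bool = False):
    ...
    if self.batch_size == "auto":
        # Longest tokenized input context across all requests; each request's
        # args are assumed to be (context_string, gen_kwargs).
        max_input_length = max(len(self.tok_encode(req.args[0])) for req in requests)

        # Per-request gen_kwargs may override max_gen_toks; the model-level
        # default is used here for simplicity.
        estimated_max_length = min(
            max_input_length + self.max_gen_toks, self.max_length
        )
        batch_size = self._detect_batch_size(max_length=estimated_max_length)
    ...
```

Clamping to `self.max_length` keeps the estimate from exceeding the model's context window, and because the new `max_length` argument would default to `None`, the existing `loglikelihood` call sites would keep their current behavior.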
This approach would provide a much more realistic sequence length for batch size estimation, unlocking the performance benefits of the `auto` batch size feature for generation tasks.
Thank you for considering this improvement!