🐛 Bug
When using a fantasy model with a batch dimension and evaluating its posterior at a test set that adds further batch dimensions, GPyTorch / PyTorch internally tries to allocate excessively large matrices.
To reproduce
The following example uses gpytorch.settings.fast_pred_var(), which is what botorch uses and is the case I've been investigating locally. Without this setting it still blows up with a memory error, just in a different place.
import torch

import gpytorch.settings
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP

torch.set_default_dtype(torch.double)


class SimpleGP(ExactGP):
    def __init__(self, train_inputs, train_targets):
        super().__init__(train_inputs, train_targets, GaussianLikelihood())
        self.mean_module = ConstantMean()
        self.covar_module = RBFKernel()

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)


if __name__ == "__main__":
    d = 2
    n_train = 100
    gp = SimpleGP(
        train_inputs=torch.rand(n_train, d, dtype=torch.double),
        train_targets=torch.rand(n_train, dtype=torch.double),
    ).eval()
    gp(torch.rand(5, d, dtype=torch.double))  # set the caches before fantasize.

    num_fantasies = 64
    x = torch.rand(256, 1, d)
    y = torch.rand(num_fantasies, 256, 1)
    fantasy_model = gp.get_fantasy_model(x, y).eval()

    x_test = torch.rand(50, num_fantasies, 256, 1, d)
    with gpytorch.settings.fast_pred_var():
        fantasy_model(x_test)  # Tries to allocate a tensor of 64GB
Traceback (most recent call last):
File ".../mwe_fantasize_memory_issue.py", line 43, in <module>
fantasy_model(x_test) # Tries to allocate a tensor of 64GB
~~~~~~~~~~~~~^^^^^^^^
File ".../lib/python3.13/site-packages/gpytorch/models/exact_gp.py", line 345, in __call__
) = self.prediction_strategy.exact_prediction(full_mean, full_covar)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
File ".../lib/python3.13/site-packages/gpytorch/models/exact_prediction_strategies.py", line 325, in exact_prediction
self.exact_predictive_covar(test_test_covar, test_train_covar),
~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../lib/python3.13/site-packages/gpytorch/models/exact_prediction_strategies.py", line 411, in exact_predictive_covar
covar_inv_quad_form_root = self._exact_predictive_covar_inv_quad_form_root(precomputed_cache, test_train_covar)
File ".../lib/python3.13/site-packages/gpytorch/models/exact_prediction_strategies.py", line 117, in _exact_predictive_covar_inv_quad_form_root
return test_train_covar.matmul(precomputed_cache)
~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
RuntimeError: [enforce fail at alloc_cpu.cpp:118] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 66853273600 bytes. Error code 12 (Cannot allocate memory)
Expected Behavior
I would expect GPyTorch to be able to do this computation without excessive memory requests.
System information
- GPyTorch Version: 1.14
- PyTorch Version: 2.6.0+cpu
- Computer OS: Ubuntu 20.04.6 LTS (Focal Fossa)
Additional context
The memory request occurs on the line return test_train_covar.matmul(precomputed_cache). The sizes of the operands are:
- test_train_covar: (50, 64, 256, 1, 101)
- precomputed_cache: (256, 101, 101)
I believe that internally torch.matmul physically broadcasts precomputed_cache to shape (50, 64, 256, 101, 101), which causes the memory issue: in double precision that expanded tensor takes 50 * 64 * 256 * 101 * 101 * 8 = 66,853,273,600 bytes, exactly the failed allocation in the traceback. I also believe this broadcast is unnecessary (see this PyTorch issue: pytorch/pytorch#154128). However, I am opening this issue in GPyTorch as well because (a) perhaps there is a simpler fix in GPyTorch, since we have better information about the dimensions of the tensors (although maybe not - I'm not sure what dimension combinations are possible), and (b) without the gpytorch.settings.fast_pred_var() setting the memory issue occurs in a different place (I think during the Cholesky factorisation), and my suggested fix in PyTorch would not address that.
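To illustrate point (a), here is a rough sketch (my own, not a proposed patch) of how this particular product could avoid the physical broadcast: precomputed_cache only varies over its single batch dimension, so that dimension can be moved to the front of test_train_covar and the remaining batch dimensions flattened into rows, turning the broadcasted matmul into a plain torch.bmm. The shapes below are small stand-ins for the real (50, 64, 256, 1, 101) and (256, 101, 101):

import torch

# Small stand-in shapes; the real ones are (50, 64, 256, 1, 101) and (256, 101, 101).
test_train_covar = torch.randn(5, 4, 6, 1, 7, dtype=torch.double)
precomputed_cache = torch.randn(6, 7, 7, dtype=torch.double)

# Current behaviour: matmul broadcasts the cache over the leading batch dims.
expected = test_train_covar.matmul(precomputed_cache)

# Alternative: align the cache's batch dim explicitly and batch over it with bmm,
# so no expanded copy of precomputed_cache is ever materialised.
moved = test_train_covar.permute(2, 0, 1, 3, 4)   # (6, 5, 4, 1, 7)
flat = moved.reshape(6, -1, 7)                    # (6, 20, 7)
out = torch.bmm(flat, precomputed_cache)          # (6, 20, 7)
result = out.reshape(6, 5, 4, 1, 7).permute(1, 2, 0, 3, 4)

print(torch.allclose(result, expected))  # True

With the real shapes the largest intermediate would be the flattened (256, 3200, 101) operand / result, roughly 0.7 GB in double precision, instead of the ~66 GB broadcast copy. Whether this generalises to all the batch-shape combinations the prediction strategies have to support is the part I'm unsure about.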