[Bug] Posterior on fantasy models can be memory inefficient on batches #2652

@JackBuck

Description

🐛 Bug

When using a fantasy model with a batch dimension and retrieving the posterior at a test set with additional batch dimensions, GPyTorch / PyTorch internally tries to allocate excessively large matrices.

To reproduce

The following example uses gpytorch.settings.fast_pred_var(), which is what BoTorch uses and is the case I have been investigating locally. Without this setting it still blows up with memory issues, just in a different place.

import torch
import gpytorch.settings
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP

torch.set_default_dtype(torch.double)


class SimpleGP(ExactGP):
    def __init__(self, train_inputs, train_targets):
        super().__init__(train_inputs, train_targets, GaussianLikelihood())
        self.mean_module = ConstantMean()
        self.covar_module = RBFKernel()

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)


if __name__ == "__main__":
    d = 2
    n_train = 100
    gp = SimpleGP(
        train_inputs=torch.rand(n_train, d, dtype=torch.double),
        train_targets=torch.rand(n_train, dtype=torch.double),
    ).eval()
    gp(torch.rand(5, d, dtype=torch.double))  # set the caches before fantasize.

    num_fantasies = 64
    x = torch.rand(256, 1, d)
    y = torch.rand(num_fantasies, 256, 1)
    fantasy_model = gp.get_fantasy_model(x, y).eval()

    x_test = torch.rand(50, num_fantasies, 256, 1, d)
    with gpytorch.settings.fast_pred_var():
        fantasy_model(x_test)  # Tries to allocate a tensor of 64GB

Traceback (most recent call last):
  File ".../mwe_fantasize_memory_issue.py", line 43, in <module>
    fantasy_model(x_test)  # Tries to allocate a tensor of 64GB
    ~~~~~~~~~~~~~^^^^^^^^
  File ".../lib/python3.13/site-packages/gpytorch/models/exact_gp.py", line 345, in __call__
    ) = self.prediction_strategy.exact_prediction(full_mean, full_covar)
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.13/site-packages/gpytorch/models/exact_prediction_strategies.py", line 325, in exact_prediction
    self.exact_predictive_covar(test_test_covar, test_train_covar),
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.13/site-packages/gpytorch/models/exact_prediction_strategies.py", line 411, in exact_predictive_covar
    covar_inv_quad_form_root = self._exact_predictive_covar_inv_quad_form_root(precomputed_cache, test_train_covar)
  File ".../lib/python3.13/site-packages/gpytorch/models/exact_prediction_strategies.py", line 117, in _exact_predictive_covar_inv_quad_form_root
    return test_train_covar.matmul(precomputed_cache)
           ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
RuntimeError: [enforce fail at alloc_cpu.cpp:118] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 66853273600 bytes. Error code 12 (Cannot allocate memory)

Expected Behavior

I would expect GPyTorch to be able to do this computation without excessive memory requests.

System information

  • GPyTorch Version: 1.14
  • PyTorch Version: 2.6.0+cpu
  • Computer OS: Ubuntu 20.04.6 LTS (Focal Fossa)

Additional context

The memory request occurs on the line return test_train_covar.matmul(precomputed_cache). The sizes of the matrices are:

  • test_train_covar: (50, 64, 256, 1, 101)
  • precomputed_cache: (256, 101, 101)

I believe that, internally, torch.matmul physically broadcasts precomputed_cache to shape (50, 64, 256, 101, 101), which causes the memory issue. I also believe this broadcast is unnecessary (see this PyTorch issue: pytorch/pytorch#154128). However, I am opening this issue in GPyTorch as well because (a) there may be a simpler fix in GPyTorch, since it has better information about the dimensions of the tensors involved (although maybe not - I'm not sure what dimensions are possible), and (b) without the gpytorch.settings.fast_pred_var() setting the memory issue occurs in a different place (I think during the Cholesky factorisation), which my suggested fix in PyTorch would not address.
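
For illustration, here is a rough sketch of the memory arithmetic and one possible call-site workaround, using the shapes above. The workaround (moving the shared 256-batch to the front, folding the remaining batch dimensions into the row dimension, and using torch.bmm) is only meant to show that the broadcast can be avoided for this particular shape pattern; it is not a proposed fix inside GPyTorch, and the tensor names are just placeholders standing in for the tensors in the traceback.

import torch

torch.set_default_dtype(torch.double)

# Stand-ins for the tensors from the traceback.
test_train_covar = torch.rand(50, 64, 256, 1, 101)
precomputed_cache = torch.rand(256, 101, 101)

# torch.matmul would broadcast precomputed_cache to (50, 64, 256, 101, 101):
# 50 * 64 * 256 * 101 * 101 * 8 bytes = 66,853,273,600 bytes,
# which matches the failed allocation in the error message.
print(50 * 64 * 256 * 101 * 101 * 8)  # 66853273600

# Possible workaround at the call site: keep only the shared batch dimension
# (256) as a batch dimension and fold the others into the rows. This computes
# the same result as test_train_covar.matmul(precomputed_cache) without
# materialising the broadcast intermediate; the extra allocations are a couple
# of ~0.66 GB tensors rather than one ~66.9 GB tensor.
lhs = test_train_covar.permute(2, 0, 1, 3, 4).reshape(256, -1, 101)  # (256, 3200, 101)
out = torch.bmm(lhs, precomputed_cache)                              # (256, 3200, 101)
out = out.reshape(256, 50, 64, 1, 101).permute(1, 2, 0, 3, 4)        # (50, 64, 256, 1, 101)

Of course, this relies on precomputed_cache varying only over the shared 256 dimension; whether GPyTorch can rely on that in general is exactly the uncertainty in point (a) above.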
