
Bug: No input gradients when using CPU/MPS fallback selective scan #7

@enriquebarco

I apologize ahead of time: I found this bug with ChatGPT 5 and need to move quickly, so there will be some AI-written material here. However, this is definitely a bug! No hallucinations here.

Description:
When running Mamba2MacOS or Mamba on CPU or MPS (Apple Silicon) in training mode, the fallback selective_scan_ref path does not propagate gradients to the input tensor.

This completely breaks training for models using Mamba blocks without CUDA: outputs still change when the inputs change, but .backward() produces a zero x.grad.

Steps to reproduce:

import torch
from mamba_ssm import Mamba2MacOS

m = Mamba2MacOS(d_model=768, d_state=16, expand=2).to('cpu').train()
x = torch.randn(2, 100, 768, requires_grad=True)
y = m(x).sum()
y.backward()
print("x.grad.sum() =", float(x.grad.abs().sum()))
# Expected > 0, got 0.0

Expected Behavior:
x.grad should be non-zero in training mode for CPU/MPS when using the reference scan.
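
As a sanity check, the pure-PyTorch reference scan does propagate gradients when it is actually reached. A minimal sketch, with the import path and tensor shapes assumed (they may differ slightly in this repo):

import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_ref  # path assumed

b, d, n, l = 2, 4, 16, 32                     # batch, model dim, state dim, seq len
u = torch.randn(b, d, l, requires_grad=True)  # input to the scan
delta = torch.rand(b, d, l)                   # positive step sizes
A = -torch.rand(d, n)                         # negative values keep the recurrence stable
B = torch.randn(b, n, l)
C = torch.randn(b, n, l)
D = torch.randn(d)

out = selective_scan_ref(u, delta, A, B, C, D=D, delta_softplus=True)
out.sum().backward()
print("u.grad.abs().sum() =", float(u.grad.abs().sum()))  # > 0 here

So the missing gradient points at how the model module dispatches to the scan, not at selective_scan_ref itself.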

Cause:
mamba2_macos.py and mamba_simple.py import selective_scan_fn as a plain name at module load time.
That name is therefore frozen to the fast kernel (or None when no kernel is built), and rebinding ssi.selective_scan_fn to ssi.selective_scan_ref at runtime has no effect on the model code.
When no custom kernel is available, this leaves the forward pass on a path with broken autograd.
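
Here is a rough sketch of that binding problem and one possible call-time fix. The module paths and the helper name are illustrative, taken from the description above, not the exact code in this repo:

# What the model files effectively do today: the name is copied at import time,
# so it stays frozen to whatever was available when the module first loaded
# (the CUDA kernel, or None when no kernel is built).
#
#     from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
#
# Monkey-patching the source module afterwards does not touch that frozen copy:
#
#     import mamba_ssm.ops.selective_scan_interface as ssi
#     ssi.selective_scan_fn = ssi.selective_scan_ref   # mamba_simple.py never sees this
#
# One way out is to resolve the function at call time, through the module namespace:

import mamba_ssm.ops.selective_scan_interface as ssi

def _resolve_selective_scan():
    """Prefer the fused kernel when it is built; otherwise fall back to the
    pure-PyTorch reference scan, which supports autograd on CPU/MPS."""
    fn = getattr(ssi, "selective_scan_fn", None)
    return fn if fn is not None else ssi.selective_scan_ref

# ...and in forward(), call the resolved function instead of the frozen name:
#     y = _resolve_selective_scan()(u, delta, A, B, C, D=..., z=z, delta_softplus=True)

An equivalent alternative is to keep the existing import but re-check for the kernel inside forward(); either way the lookup has to happen at call time, not at module load.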
