Long Sequence Length Inference Mamba2: CUDA error: an illegal memory access was encountered #686

wdykas commented Feb 5, 2025

When running simple inference with Mamba2 on an H100 at a long sequence length (512000), I am hitting an illegal memory access in _mamba_chunk_scan_combined_fwd:

  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 315, in _mamba_chunk_scan_combined_fwd
    dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_chunk_state.py", line 684, in _chunk_cumsum_fwd
    _chunk_cumsum_fwd_kernel[grid_chunk_cs](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 156, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 156, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 133, in _bench
    return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
  File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 104, in do_bench
    torch.cuda.synchronize()
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 950, in synchronize
    return torch._C._cuda_synchronize()
RuntimeError: CUDA error: an illegal memory access was encountered
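
Note that the error is reported at the torch.cuda.synchronize() inside the Triton autotuner, so the faulting launch may have happened earlier. A minimal sketch for localizing it, using the standard CUDA_LAUNCH_BLOCKING environment variable (must be set before the first CUDA call, or exported in the shell):

import os

# Make kernel launches synchronous so the illegal access is reported at the
# offending launch rather than at a later torch.cuda.synchronize().
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"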

Here is a small reproducer:

import torch
from mamba_ssm import Mamba2
import torch.cuda as cuda

def get_gpu_memory():
    """Return GPU memory usage in MB"""
    return cuda.memory_allocated() / 1024**2, cuda.memory_reserved() / 1024**2

def print_memory_usage(step):
    allocated, reserved = get_gpu_memory()
    print(f"{step}:")
    print(f"  Allocated: {allocated:.2f} MB")
    print(f"  Reserved:  {reserved:.2f} MB")
    print("-" * 40)

# Model parameters
d_model = 4096
d_state = 128
d_conv = 4
expand = 2
seq_len = 512000
batch_size = 1

# Initialize device and print initial memory state
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Initial GPU memory state")
print_memory_usage("Before model creation")

# Create model
model = Mamba2(
    d_model=d_model,
    d_state=d_state,
    d_conv=d_conv,
    expand=expand,
).to(device)

print("After model creation")
print_memory_usage("After moving model to GPU")

# Create input tensor
x = torch.randn(batch_size, seq_len, d_model).to(device)
print_memory_usage("After creating input tensor")

# Run inference multiple times
n_repeats = 3
with torch.no_grad():
    for i in range(n_repeats):
        # Clear cache before each run
        torch.cuda.empty_cache()
        print(f"\nInference run {i+1}")
        print_memory_usage("Before inference")
        
        # Run inference
        output = model(x)
        print_memory_usage("After inference")
        
        # Verify output shape
        assert output.shape == x.shape, f"Expected shape {x.shape}, got {output.shape}"
        print(f"Output shape: {output.shape}")

# Final memory state
print("\nFinal GPU memory state")
print_memory_usage("After all runs")

# Model size information
total_params = sum(p.numel() for p in model.parameters())
print(f"\nModel Information:")
print(f"Total parameters: {total_params:,}")
print(f"Theoretical model size: {total_params * 4 / 1024**2:.2f} MB (FP32)")

The first inference run seems to complete fine, but the second hits the illegal memory access. Are modifications needed to the Mamba2 kernels to support longer sequence lengths? I am assuming this may be a range/indexing overflow issue of some kind.
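
For what it's worth, a rough element-count check (a sketch assuming the Mamba2 defaults of headdim=64, chunk_size=256, and d_state=128, which I have not verified against the installed version) suggests the flattened hidden-state tensor already exceeds the int32 range at this sequence length, which is where 32-bit pointer offsets inside Triton kernels would typically overflow:

INT32_MAX = 2**31 - 1

d_model, expand, seq_len, batch_size = 4096, 2, 512000, 1
headdim, chunk_size, d_state = 64, 256, 128    # assumed Mamba2 defaults
d_inner = expand * d_model                     # 8192
nheads = d_inner // headdim                    # 128
nchunks = (seq_len + chunk_size - 1) // chunk_size

# Element counts of some of the large tensors in the chunked scan path.
sizes = {
    "hidden states x (B, L, d_inner)": batch_size * seq_len * d_inner,
    "dt (B, L, nheads)": batch_size * seq_len * nheads,
    "chunk states (B, nchunks, nheads, headdim, d_state)":
        batch_size * nchunks * nheads * headdim * d_state,
}
for name, n in sizes.items():
    print(f"{name}: {n:,} elements, exceeds int32: {n > INT32_MAX}")

On these assumed shapes, the (B, L, d_inner) activation alone is about 4.2 billion elements, which would overflow a 32-bit offset.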
