When running simple inference with Mamba 2 on an H100 at a long sequence length (seq_len = 512000), I am hitting an illegal memory access in `_mamba_chunk_scan_combined_fwd`:
```
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 315, in _mamba_chunk_scan_combined_fwd
    dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_chunk_state.py", line 684, in _chunk_cumsum_fwd
    _chunk_cumsum_fwd_kernel[grid_chunk_cs](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 156, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 156, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 133, in _bench
    return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
  File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 104, in do_bench
    torch.cuda.synchronize()
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 950, in synchronize
    return torch._C._cuda_synchronize()
RuntimeError: CUDA error: an illegal memory access was encountered
```
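Note that the error only surfaces at `torch.cuda.synchronize()` inside Triton's autotuner, and CUDA reports illegal accesses asynchronously, so the kernel named in the traceback is not necessarily the one that faulted. One way to localize the faulting launch is to force synchronous launches via an environment variable, set before anything touches the GPU (e.g. at the very top of the reproducer below):

```python
import os

# Force synchronous kernel launches so the illegal access is reported
# at the offending launch site rather than at a later synchronize().
# Must be set before the CUDA context is created, i.e. before any
# tensor is moved to the GPU.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
```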
Here is a small reproducer:
```python
import torch
from mamba_ssm import Mamba2
import torch.cuda as cuda


def get_gpu_memory():
    """Return GPU memory usage in MB"""
    return cuda.memory_allocated() / 1024**2, cuda.memory_reserved() / 1024**2


def print_memory_usage(step):
    allocated, reserved = get_gpu_memory()
    print(f"{step}:")
    print(f" Allocated: {allocated:.2f} MB")
    print(f" Reserved: {reserved:.2f} MB")
    print("-" * 40)


# Model parameters
d_model = 4096
d_state = 128
d_conv = 4
expand = 2
seq_len = 512000
batch_size = 1

# Initialize device and print initial memory state
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Initial GPU memory state")
print_memory_usage("Before model creation")

# Create model
model = Mamba2(
    d_model=d_model,
    d_state=d_state,
    d_conv=d_conv,
    expand=expand,
).to(device)
print("After model creation")
print_memory_usage("After moving model to GPU")

# Create input tensor
x = torch.randn(batch_size, seq_len, d_model).to(device)
print_memory_usage("After creating input tensor")

# Run inference multiple times
n_repeats = 3
with torch.no_grad():
    for i in range(n_repeats):
        # Clear cache before each run
        torch.cuda.empty_cache()
        print(f"\nInference run {i+1}")
        print_memory_usage("Before inference")

        # Run inference
        output = model(x)
        print_memory_usage("After inference")

        # Verify output shape
        assert output.shape == x.shape, f"Expected shape {x.shape}, got {output.shape}"
        print(f"Output shape: {output.shape}")

# Final memory state
print("\nFinal GPU memory state")
print_memory_usage("After all runs")

# Model size information
total_params = sum(p.numel() for p in model.parameters())
print(f"\nModel Information:")
print(f"Total parameters: {total_params:,}")
print(f"Theoretical model size: {total_params * 4 / 1024**2:.2f} MB (FP32)")
```
The first inference run seems to complete fine, but on the second there is an illegal memory access. Are there modifications needed to the Mamba 2 kernels to support longer sequence lengths? I am assuming this may be a range-indexing issue of some kind.
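For what it's worth, a quick back-of-the-envelope check supports the indexing hypothesis: with expand = 2 the inner dimension is d_inner = expand * d_model = 8192, so a single activation tensor holds seq_len * d_inner = 512000 * 8192 = 4,194,304,000 elements, which exceeds 2^31 - 1. If any of the kernels compute flat offsets in 32-bit integers, they would wrap at this size. A minimal sketch of that check (the 2**31 - 1 threshold is my assumption about where 32-bit offset arithmetic overflows):

```python
# Hypothetical sanity check: flag tensors whose flat element count
# exceeds the int32 range, where 32-bit offset arithmetic would wrap.
INT32_MAX = 2**31 - 1  # assumed overflow threshold for 32-bit offsets

batch_size, seq_len, d_model, expand = 1, 512000, 4096, 2
d_inner = expand * d_model  # 8192

n_elements = batch_size * seq_len * d_inner
print(f"Elements per activation tensor: {n_elements:,}")
print(f"Exceeds int32 range: {n_elements > INT32_MAX}")
# -> 4,194,304,000 elements, which is > 2,147,483,647
```

If that is indeed the cause, the usual mitigation in Triton kernels is to promote the offset arithmetic to 64-bit (e.g. casting program ids or strides to `tl.int64` before multiplying), though I have not verified that against the mamba_ssm kernels.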