I’ve open-sourced mamba2-jax, an experimental but stable JAX/Flax implementation of Mamba2 (“Transformers are SSMs”, Dao & Gu, ICML 2024).
– GitHub: https://github.com/CosmoNaught/mamba2-jax
– PyPI: https://pypi.org/project/mamba2-jax/
The goal is to provide a pure JAX alternative to vasqu’s excellent PyTorch implementation, for people who are already in the JAX ecosystem or want TPU-native Mamba2 blocks without Triton/CUDA kernels.
What's in the box?
- Mamba2 core in JAX/Flax (no Triton / custom CUDA kernels)
- `Mamba2ForCausalLM` for causal language modeling
- `Mamba2Forecaster` for time-series forecasting
- Hooks for streaming/stateful inference and `output_hidden_states=True`
- Runs on CPU / CUDA / TPU, wherever JAX runs
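To give a feel for the intended usage, here's a minimal sketch of driving the causal LM class as a Flax module. The import path, constructor arguments, and call pattern below are assumptions based on standard Flax conventions rather than the library's confirmed API, so please check the repo README for the real interface:

```python
# Hypothetical usage sketch -- import path, constructor args, and call pattern
# follow common Flax conventions and are assumptions, not the confirmed API.
import jax
import jax.numpy as jnp
from mamba2_jax import Mamba2ForCausalLM  # class name from the feature list above

rng = jax.random.PRNGKey(0)
tokens = jnp.ones((1, 128), dtype=jnp.int32)  # (batch, seq_len) dummy token ids

model = Mamba2ForCausalLM(vocab_size=32_000, d_model=512, n_layer=8)  # assumed args
params = model.init(rng, tokens)              # standard Flax init
logits = model.apply(params, tokens)          # expected shape: (1, 128, vocab_size)
```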
Validation vs PyTorch
Small CPU-only parity test vs mamba2-torch on a synthetic MSE regression task:
- Similar loss curves; final MSE diff ≈ 0.012
- Prediction Pearson r ≈ 0.99
- After JIT warmup, JAX is ≈ 2.2× faster per step on CPU
*Figure: mamba2-jax vs mamba2-torch validation (small numerical stability test)*
Full details can be found [here](https://github.com/CosmoNaught/mamba2-jax/blob/main/README.md#numerical-validation-with-pytorch) in the repo.
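If you want to reproduce that kind of comparison yourself, the parity numbers above boil down to a few lines of NumPy. This is a generic sketch, not the repo's actual benchmark script; the `.npy` filenames are hypothetical dumps of predictions from the two implementations on the same synthetic regression data:

```python
import numpy as np

# Hypothetical prediction dumps from the JAX and PyTorch runs on the same
# synthetic MSE regression task (these filenames are placeholders).
jax_preds = np.load("preds_jax.npy")
torch_preds = np.load("preds_torch.npy")
targets = np.load("targets.npy")

mse_jax = np.mean((jax_preds - targets) ** 2)
mse_torch = np.mean((torch_preds - targets) ** 2)
pearson_r = np.corrcoef(jax_preds.ravel(), torch_preds.ravel())[0, 1]

print(f"final MSE diff : {abs(mse_jax - mse_torch):.4f}")  # reported ≈ 0.012
print(f"prediction r   : {pearson_r:.3f}")                 # reported ≈ 0.99
```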
Status / caveats
- Validated across CPUs, CUDA GPUs, Apple Silicon / M-series (MPS), and Google Cloud TPUs, so you should be good to go!
- Alpha, API may still move a bit
- No pretrained weights yet
- GPU/TPU support is functional but not heavily profiled (haven't had time yet, sadly!)
Feedback welcome on
- API design for research use
- Missing hooks for analysis / custom losses
- Real-world benchmarks on larger models or longer sequences
I’m an independent researcher (not affiliated with the original Mamba2 or JAX teams) and would really appreciate any feedback or bug reports!!
Thanks everyone for your time, and have a great day!