[P] mamba2-jax is here! Pure JAX/Flax implementation of Mamba2 (≈2× faster CPU inference vs PyTorch on my micro-benchmark)

Hey guys!

I’ve open-sourced mamba2-jax, an experimental but stable JAX/Flax implementation of Mamba2 (“Transformers are SSMs”, Dao & Gu, ICML 2024).

– GitHub: https://github.com/CosmoNaught/mamba2-jax

– PyPI: https://pypi.org/project/mamba2-jax/

The goal is to provide a pure JAX alternative to vasqu’s excellent PyTorch implementation, for people who are already in the JAX ecosystem or who want TPU-native Mamba2 blocks without Triton/CUDA kernels.

What's in the box?

  • Mamba2 core in JAX/Flax (no Triton / custom CUDA kernels)
  • Mamba2ForCausalLM for causal LM
  • Mamba2Forecaster for time-series forecasting
  • Hooks for streaming/stateful inference and output_hidden_states=True
  • Runs on CPU / CUDA GPU / TPU, wherever JAX runs (quick usage sketch below)
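
Here’s a rough sketch of what basic usage might look like, assuming the standard Flax init/apply pattern; the import path, constructor arguments, and the placement of the output_hidden_states flag are illustrative guesses, so check the README for the actual API:

```python
# Hypothetical usage sketch -- constructor arguments and call signature are
# assumptions, not the package's documented API.
import jax
import jax.numpy as jnp
from mamba2_jax import Mamba2ForCausalLM  # import path assumed

# Illustrative hyperparameters; the real config fields may be named differently.
model = Mamba2ForCausalLM(vocab_size=32_000, d_model=512, n_layer=8)

rng = jax.random.PRNGKey(0)
input_ids = jnp.ones((1, 128), dtype=jnp.int32)  # (batch, seq_len)

# Standard Flax pattern: init builds the parameter pytree, apply runs the forward pass.
params = model.init(rng, input_ids)
logits = model.apply(params, input_ids)  # expected shape: (batch, seq_len, vocab_size)

# The post mentions an output_hidden_states=True hook; shown here as a
# call-time flag, though it may live on the config instead.
outputs = model.apply(params, input_ids, output_hidden_states=True)
```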

Validation vs PyTorch

Small CPU-only parity test vs mamba2-torch on a synthetic MSE regression task:

  • Similar loss curves; final MSE diff ≈ 0.012
  • Prediction Pearson r ≈ 0.99
  • After JIT warmup, JAX is ≈ 2.2× faster per step on CPU (see the timing sketch below)

mamba2-jax vs mamba2-torch validation (small numerical stability test)

Full details can be found [here](https://github.com/CosmoNaught/mamba2-jax/blob/main/README.md#numerical-validation-with-pytorch) in the repo.
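
For anyone curious what “after JIT warmup” means in practice, the snippet below shows the usual JAX timing pattern (a toy stand-in step, not the actual benchmark script from the repo): compile once, then time subsequent calls with block_until_ready so async dispatch doesn’t skew the numbers.

```python
import time
import jax
import jax.numpy as jnp

# Toy stand-in for a forward/training step; in the real benchmark this would be
# the Mamba2 forward pass (+ MSE loss) on the synthetic regression task.
@jax.jit
def step(x, w):
    return jnp.tanh(x @ w).mean()

x = jnp.ones((64, 512))
w = jnp.ones((512, 512))

# First call triggers XLA compilation -- exclude it from the measurement.
step(x, w).block_until_ready()

t0 = time.perf_counter()
n = 100
for _ in range(n):
    step(x, w).block_until_ready()  # block so we time actual compute, not dispatch
print(f"mean step time: {(time.perf_counter() - t0) / n * 1e3:.3f} ms")
```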

Status / caveats

  • Validated across CPUs, CUDA GPUs, Apple Silicon / M-series, and Google Cloud TPUs, so you should be good to go on most setups
  • Alpha, API may still move a bit
  • No pretrained weights yet
  • GPU/TPU support is functional but not heavily profiled (haven’t had time yet, sadly!)

Feedback welcome on

  • API design for research use
  • Missing hooks for analysis / custom losses
  • Real-world benchmarks on larger models or longer sequences

I’m an independent researcher (not affiliated with the original Mamba2 or JAX teams) and would really appreciate any feedback or bug reports!!

Thanks everyone for your time, and have a great day!
