awesome-jax

JAX library catalog

A curated list of resources and libraries for JAX machine learning research

JAX - A curated list of resources https://github.com/google/jax

2k stars
50 watching
132 forks
last commit: 4 months ago
Linked from 2 awesome lists

autograd, awesome, awesome-list, deep-learning, jax, machine-learning, neural-network, numpy, xla

Awesome JAX / Libraries

Neural Network Libraries

Centered on flexibility and clarity
Haiku 2,907 13 days ago Focused on simplicity, created by the authors of Sonnet at DeepMind
Objax 771 10 months ago Has an object oriented design similar to PyTorch
Elegy A High Level API for Deep Learning in JAX. Supports Flax, Haiku, and Optax
Trax 8,096 2 months ago "Batteries included" deep learning library focused on providing solutions for common workloads
Jraph 1,375 8 months ago Lightweight graph neural network library
Neural Tangents 2,278 9 months ago High-level API for specifying neural networks of both finite and infinite width
HuggingFace 135,022 6 days ago Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax)
Equinox 2,118 21 days ago Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX
Scenic 3,332 about 1 month ago A Jax Library for Computer Vision Research and Beyond
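
The neural network libraries above differ in API, but they all build on the same JAX pattern: parameters are a PyTree (nested lists and dicts of arrays) passed explicitly to a pure function, which `jax.jit` and `jax.grad` can then transform. A minimal sketch in plain JAX, using none of these libraries; `init_mlp` and `mlp` are hypothetical helpers written only for illustration:

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    """Build a list of {"w", "b"} dicts: the parameter PyTree (hypothetical helper)."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, (n_in, n_out) in zip(keys, zip(sizes[:-1], sizes[1:])):
        w = jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append({"w": w, "b": jnp.zeros(n_out)})
    return params

def mlp(params, x):
    """Pure forward pass: tanh hidden layers, linear output layer."""
    for layer in params[:-1]:
        x = jnp.tanh(x @ layer["w"] + layer["b"])
    last = params[-1]
    return x @ last["w"] + last["b"]

params = init_mlp(jax.random.PRNGKey(0), [3, 16, 1])
x = jnp.ones((8, 3))
y = jax.jit(mlp)(params, x)  # compiled forward pass, shape (8, 1)

# Gradients come back with exactly the same PyTree structure as params.
grads = jax.grad(lambda p: jnp.mean(mlp(p, x) ** 2))(params)
```

Roughly speaking, libraries such as Haiku, Flax, or Equinox automate the bookkeeping done by hand in `init_mlp` (shapes, initialization, nesting) while keeping a pure function of parameters at the core.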

Awesome JAX / Libraries

Levanter 516 3 days ago Legible, Scalable, Reproducible Foundation Models with Named Tensors and JAX
EasyLM 2,409 3 months ago LLMs made easy: Pre-training, finetuning, evaluating and serving LLMs in JAX/Flax
NumPyro 2,290 4 days ago Probabilistic programming based on the Pyro library
Chex 788 8 days ago Utilities to write and test reliable JAX code
Optax 1,697 9 days ago Gradient processing and optimization library
RLax 1,263 about 2 months ago Library for implementing reinforcement learning agents
JAX, M.D. 1,185 21 days ago Accelerated, differentiable molecular dynamics
Coax 167 almost 2 years ago Turn RL papers into code, the easy way
Distrax 536 2 months ago Reimplementation of TensorFlow Probability, containing probability distributions and bijectors
cvxpylayers 1,819 9 days ago Construct differentiable convex optimization layers
TensorLy 1,566 9 days ago Tensor learning made simple
NetKet 548 6 days ago Machine Learning toolbox for Quantum Physics
Fortuna 893 27 days ago AWS library for Uncertainty Quantification in Deep Learning
BlackJAX 846 22 days ago Library of samplers for JAX
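
Optax (above) treats optimization as a chain of gradient transformations. The sketch below illustrates that idea by hand in plain JAX: it clips gradients by their global norm, then applies heavy-ball momentum over a parameter PyTree. It does not use Optax's API, and the learning rate, momentum, clipping threshold, and synthetic linear-regression data are arbitrary choices for illustration:

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    """Mean squared error of a linear model."""
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def clip_by_global_norm(grads, max_norm):
    """Rescale all gradient leaves so their joint L2 norm is at most max_norm."""
    leaves = jax.tree_util.tree_leaves(grads)
    norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    scale = jnp.minimum(1.0, max_norm / (norm + 1e-6))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)

def momentum_update(params, velocity, grads, lr=0.1, beta=0.9):
    """Heavy-ball momentum, applied leaf-by-leaf over the parameter PyTree."""
    velocity = jax.tree_util.tree_map(lambda v, g: beta * v + g, velocity, grads)
    params = jax.tree_util.tree_map(lambda p, v: p - lr * v, params, velocity)
    return params, velocity

@jax.jit
def train_step(params, velocity, x, y):
    grads = jax.grad(loss)(params, x, y)
    grads = clip_by_global_norm(grads, max_norm=1.0)  # gradient processing
    return momentum_update(params, velocity, grads)   # optimizer update

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (64, 3))
y = x @ jax.random.normal(key_w, (3,)) + 0.5  # synthetic targets
params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}
velocity = jax.tree_util.tree_map(jnp.zeros_like, params)

initial_loss = loss(params, x, y)
for _ in range(200):
    params, velocity = train_step(params, velocity, x, y)
final_loss = loss(params, x, y)
```

Keeping each transformation a separate pure function over PyTrees is what makes them easy to chain; a library like Optax packages this pattern with many more transformations.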

Awesome JAX / Libraries / New Libraries

Neural Network Libraries

Federated learning in JAX, built on Optax and Haiku
Equivariant MLP 257 over 1 year ago Construct equivariant neural network layers
jax-resnet 104 over 2 years ago Implementations and checkpoints for ResNet variants in Flax
Parallax 155 over 4 years ago Immutable Torch Modules for JAX

Awesome JAX / Libraries / New Libraries

jax-unirep 104 3 months ago Library implementing the for protein machine learning applications
jax-flows 274 over 1 year ago Normalizing flows in JAX
sklearn-jax-kernels 42 about 4 years ago kernel matrices using JAX
jax-cosmo 178 4 months ago Differentiable cosmology library
efax 55 4 days ago Exponential Families in JAX
mpi4jax 445 17 days ago Combine MPI operations with your Jax code on CPUs and GPUs
imax 37 8 months ago Image augmentations and transformations
FlaxVision 44 almost 4 years ago Flax version of TorchVision
Oryx 4,269 6 days ago Probabilistic programming language based on program transformations
Optimal Transport Tools 213 almost 3 years ago Toolbox that bundles utilities to solve optimal transport problems
delta PV 60 almost 2 years ago A photovoltaic simulator with automatic differentiation
jaxlie 233 2 months ago Lie theory library for rigid body transformations and optimization
BRAX 2,337 7 days ago Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments
flaxmodels 238 over 1 year ago Pretrained models for Jax/Flax
CR.Sparse 88 about 1 year ago XLA accelerated algorithms for sparse representations and compressive sensing
exojax 57 about 1 month ago Automatically differentiable spectrum modeling of exoplanets/brown dwarfs, compatible with JAX
JAXopt 933 2 months ago Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX
PIX 389 13 days ago PIX is an image processing library in JAX, for JAX
bayex 84 7 months ago Bayesian Optimization powered by JAX
JaxDF 121 2 months ago Framework for differentiable simulators with arbitrary discretizations
tree-math 188 6 months ago Convert functions that operate on arrays into functions that operate on PyTrees
jax-models 151 over 2 years ago Implementations of research papers originally without code or code written with frameworks other than JAX
PGMax 64 about 1 month ago A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX
EvoJAX 843 5 months ago Hardware-Accelerated Neuroevolution
evosax 506 about 1 month ago JAX-Based Evolution Strategies
SymJAX 120 over 1 year ago Symbolic CPU/GPU/TPU programming
mcx 325 8 months ago Express & compile probabilistic programs for performant inference
Einshape 99 5 months ago DSL-based reshaping library for JAX and other frameworks
ALX 34,295 6 days ago Open-source library for distributed matrix factorization using Alternating Least Squares, more info in
Diffrax 1,442 3 days ago Numerical differential equation solvers in JAX
tinygp 296 11 days ago The of Gaussian process libraries in JAX
gymnax 650 5 months ago Reinforcement Learning Environments with the well-known gym API
Mctx 2,356 4 months ago Monte Carlo tree search algorithms in native JAX
KFAC-JAX 248 5 days ago Second Order Optimization with Approximate Curvature for NNs
TF2JAX 105 17 days ago Convert functions/graphs to JAX functions
jwave 143 2 months ago A library for differentiable acoustic simulations
GPJax 461 20 days ago Gaussian processes in JAX
Jumanji 622 6 days ago A Suite of Industry-Driven Hardware-Accelerated RL Environments written in JAX
Eqxvision 102 4 months ago Equinox version of Torchvision
JAXFit 51 over 1 year ago Accelerated curve fitting library for nonlinear least-squares problems (see )
econpizza 78 9 days ago Solve macroeconomic models with heterogeneous agents using JAX
SPU 242 3 days ago A domain-specific compiler and runtime suite to run JAX code with MPC (Secure Multi-Party Computation)
jax-tqdm 96 28 days ago Add a tqdm progress bar to JAX scans and loops
safejax 42 6 months ago Serialize JAX, Flax, Haiku, or Objax model params with 🤗
Kernex 66 about 1 year ago Differentiable stencil decorators in JAX
MaxText 1,529 4 days ago A simple, performant and scalable Jax LLM written in pure Python/Jax and targeting Google Cloud TPUs
Pax 457 9 days ago A Jax-based machine learning framework for training large scale models
Praxis 178 10 days ago The layer library for Pax with a goal to be usable by other JAX-based ML projects
purejaxrl 722 2 months ago Vectorisable, end-to-end RL algorithms in JAX
Lorax 132 9 months ago Automatically apply LoRA to JAX models (Flax, Haiku, etc.)
SCICO 105 13 days ago Scientific computational imaging in JAX
Spyx 101 about 1 month ago Spiking Neural Networks in JAX for machine learning on neuromorphic hardware
BrainPy 533 9 days ago Brain Dynamics Programming in Python
OTT-JAX 525 3 days ago Optimal transport tools in JAX
QDax 266 about 1 month ago Quality Diversity optimization in Jax
JAX Toolbox 245 3 days ago Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries such as T5x, Paxml, and Transformer Engine
Pgx Vectorized board game environments for RL with an AlphaZero example
EasyDeL 206 5 days ago EasyDeL 🔮 is an open-source library for faster, more optimized training and serving of models (Llama, MPT, Mixtral, Falcon, etc.) in JAX
XLB 229 17 days ago A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning
dynamiqs 161 6 days ago High-performance and differentiable simulations of quantum systems with JAX
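
Several of the simulation libraries above (for example Diffrax, BRAX, and JAX, M.D.) advance a state through many time steps. In JAX such loops are usually written with `jax.lax.scan`, which compiles the whole loop instead of unrolling it in Python; this is also the kind of loop that jax-tqdm attaches progress bars to. A minimal fixed-step explicit Euler solver as a sketch; `euler_solve` is a hypothetical name, and this is far simpler than Diffrax, which offers many more solvers and adaptive step sizes:

```python
import jax
import jax.numpy as jnp

def euler_solve(f, y0, t0, t1, num_steps):
    """Fixed-step explicit Euler for dy/dt = f(t, y), written as one lax.scan."""
    dt = (t1 - t0) / num_steps
    ts = t0 + dt * jnp.arange(num_steps)  # left endpoint of each step

    def step(y, t):
        y_next = y + dt * f(t, y)
        return y_next, y_next  # (new carry, value recorded in the trajectory)

    y_final, trajectory = jax.lax.scan(step, y0, ts)
    return y_final, trajectory

def decay(t, y):
    """dy/dt = -y, whose exact solution is y0 * exp(-t)."""
    return -y

# f and num_steps are static: f is a Python callable, num_steps fixes array shapes.
solve = jax.jit(euler_solve, static_argnums=(0, 4))
y_final, traj = solve(decay, jnp.array(1.0), 0.0, 1.0, 1000)
```

With 1000 steps the Euler result at t = 1 lands within about 2e-4 of the exact value exp(-1); a higher-order or adaptive solver would need far fewer steps for the same accuracy.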

Awesome JAX / Models and Projects / JAX

Official implementation of
kalman-jax 95 over 1 year ago Approximate inference for Markov (i.e., temporal) Gaussian processes using iterated Kalman filtering and smoothing
jaxns 147 8 days ago Nested sampling in JAX
Amortized Bayesian Optimization 34,295 6 days ago Code related to
Accurate Quantized Training 34,295 6 days ago Tools and libraries for running and analyzing neural network quantization experiments in JAX and Flax
BNN-HMC 34,295 6 days ago Implementation for the paper
JAX-DFT 34,295 6 days ago One-dimensional density functional theory (DFT) in JAX, with implementation of
Robust Loss 34,295 6 days ago Reference code for the paper
Symbolic Functionals 34,295 6 days ago Demonstration from
TriMap 34,295 6 days ago Official JAX implementation of

Awesome JAX / Models and Projects / Flax

Performer 34,295 6 days ago Flax implementation of the Performer (linear transformer via FAVOR+) architecture
JaxNeRF 34,295 6 days ago Implementation of with multi-device GPU/TPU support
mip-NeRF 900 about 2 years ago Official implementation of
RegNeRF 34,295 6 days ago Official implementation of
Big Transfer (BiT) 1,513 4 months ago Implementation of
JAX RL 630 about 2 years ago Implementations of reinforcement learning algorithms
gMLP Implementation of
MLP Mixer Minimal implementation of
Distributed Shampoo 34,295 6 days ago Implementation of
NesT 193 4 months ago Official implementation of
XMC-GAN 98 27 days ago Official implementation of
FNet 34,295 6 days ago Official implementation of
GFSA 34,295 6 days ago Official implementation of
IPA-GNN 34,295 6 days ago Official implementation of
Flax Models 34,295 6 days ago Collection of models and methods implemented in Flax
Protein LM 34,295 6 days ago Implements BERT and autoregressive models for proteins, as described in and
Slot Attention 34,295 6 days ago Reference implementation for
Vision Transformer 10,450 6 months ago Official implementation of
FID computation 24 4 months ago Port of to Flax
ARDM 34,295 6 days ago Official implementation of
D3PM 34,295 6 days ago Official implementation of
Gumbel-max Causal Mechanisms 34,295 6 days ago Code for , with extra code in
Latent Programmer 34,295 6 days ago Code for the ICML 2021 paper
SNeRG 34,295 6 days ago Official implementation of
Spin-weighted Spherical CNNs 34,295 6 days ago Adaptation of
VDVAE 34,295 6 days ago Adaptation of , original code at
MUSIQ 34,295 6 days ago Checkpoints and model inference code for the ICCV 2021 paper
AQuaDem 34,295 6 days ago Official implementation of
Combiner 34,295 6 days ago Official implementation of
Dreamfields 34,295 6 days ago Official implementation of the ICLR 2022 paper
GIFT 34,295 6 days ago Official implementation of
Light Field Neural Rendering 34,295 6 days ago Official implementation of
Sharpened Cosine Similarity in JAX by Raphael Pisoni A JAX/Flax implementation of the Sharpened Cosine Similarity layer
GNNs for Solving Combinatorial Optimization Problems 42 almost 2 years ago A JAX + Flax implementation of

Awesome JAX / Models and Projects / Haiku

AlphaFold 12,859 5 months ago Implementation of the inference pipeline of AlphaFold v2.0, presented in
Adversarial Robustness 13,250 26 days ago Reference code for and
Bootstrap Your Own Latent 13,250 26 days ago Implementation for the paper
Gated Linear Networks 13,250 26 days ago GLNs are a family of backpropagation-free neural networks
Glassy Dynamics 13,250 26 days ago Open source implementation of the paper
MMV 13,250 26 days ago Code for the models in
Normalizer-Free Networks 13,250 26 days ago Official Haiku implementation of
NuX 82 12 months ago Normalizing flows with JAX
OGB-LSC 13,250 26 days ago This repository contains DeepMind's entry to the (quantum chemistry) and (academic graph) tracks of the (OGB-LSC)
Persistent Evolution Strategies 34,295 6 days ago Code used for the paper
Two Player Auction Learning 0 12 months ago JAX implementation of the paper
WikiGraphs 13,250 26 days ago Baseline code to reproduce results in

Awesome JAX / Models and Projects / Trax

Reformer 8,096 2 months ago Implementation of the Reformer (efficient transformer) architecture

Awesome JAX / Models and Projects / NumPyro

lqg 23 10 months ago Official implementation of Bayesian inverse optimal control for linear-quadratic Gaussian problems from the paper

Awesome JAX / Videos

JAX, its use at DeepMind, and discussion between engineers, scientists, and JAX core team
Introduction to JAX Simple neural network from scratch in JAX
JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas JAX's core design, how it's powering new research, and how you can start using it
Bayesian Programming with JAX + NumPyro — Andy Kitchen Introduction to Bayesian modelling using NumPyro
JAX: Accelerated machine-learning research via composable function transformations in Python | NeurIPS 2019 | Skye Wanderman-Milne JAX intro presentation in workshop
JAX on Cloud TPUs | NeurIPS 2020 | Skye Wanderman-Milne and James Bradbury Presentation of TPU host access with demo
Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond | NeurIPS 2020 Tutorial created by Zico Kolter, David Duvenaud, and Matt Johnson with Colab notebooks available in
Solving y=mx+b with Jax on a TPU Pod slice - Mat Kelcey A four part YouTube tutorial series with Colab notebooks that starts with Jax fundamentals and moves up to training with a data parallel approach on a v3-32 TPU Pod slice
JAX, Flax & Transformers 🤗 135,022 6 days ago 3 days of talks around JAX / Flax, Transformers, large-scale language modeling and other great topics
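
The y=mx+b tutorial series above starts from a core pattern: write a loss, get its gradient with `jax.value_and_grad`, and compile the update with `jax.jit`. A single-device sketch of that pattern with made-up data (slope 3, intercept -2); it does not reproduce the series' data-parallel TPU Pod setup:

```python
import jax
import jax.numpy as jnp

# Synthetic, noise-free data from the line y = 3x - 2 (values chosen for illustration).
x = jnp.linspace(-1.0, 1.0, 100)
y = 3.0 * x - 2.0

def mse(params, x, y):
    """Mean squared error of the line m * x + b."""
    m, b = params
    return jnp.mean((m * x + b - y) ** 2)

@jax.jit
def update(params, x, y, lr=0.5):
    """One step of plain gradient descent on (m, b)."""
    loss, grads = jax.value_and_grad(mse)(params, x, y)
    new_params = tuple(p - lr * g for p, g in zip(params, grads))
    return new_params, loss

params = (jnp.array(0.0), jnp.array(0.0))
for _ in range(500):
    params, loss = update(params, x, y)
m, b = params  # should approach 3.0 and -2.0
```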

Awesome JAX / Papers

Compiling machine learning programs via high-level tracing. Roy Frostig, Matthew James Johnson, Chris Leary. MLSys 2018. White paper describing an early version of JAX, detailing how computation is traced and compiled
JAX, M.D.: A Framework for Differentiable Physics. Samuel S. Schoenholz, Ekin D. Cubuk. NeurIPS 2020. Introduces JAX, M.D., a differentiable physics library which includes simulation environments, interaction potentials, neural networks, and more
Enabling Fast Differentially Private SGD via Just-in-Time Compilation and Vectorization. Pranav Subramani, Nicholas Vadivelu, Gautam Kamath. arXiv 2020. Uses JAX's JIT and vmap to achieve faster differentially private SGD than existing libraries
XLB: A Differentiable Massively Parallel Lattice Boltzmann Library in Python. Mohammadmehdi Ataei, Hesam Salehipour. arXiv 2023. White paper describing the XLB library: benchmarks, validations, and more details about the library
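
The DP-SGD paper above builds on the fact that `jax.vmap` over `jax.grad` yields per-example gradients in a single vectorized call, typically the costly step in differentially private SGD. A minimal sketch of the usual recipe (clip each example's gradient, sum, add Gaussian noise, average) for a hypothetical linear model; the clipping norm and noise multiplier are placeholder values, no privacy accounting is done, and this is not the paper's code:

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    """Squared error of a hypothetical linear model on a single example."""
    return (jnp.dot(x, w) - y) ** 2

@jax.jit
def dp_gradient(w, xs, ys, key, clip_norm=1.0, noise_multiplier=1.0):
    # Per-example gradients: map grad over the batch axis of xs and ys, not over w.
    per_example = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(w, xs, ys)
    # Clip each example's gradient to L2 norm <= clip_norm.
    norms = jnp.linalg.norm(per_example, axis=1, keepdims=True)
    clipped = per_example / jnp.maximum(1.0, norms / clip_norm)
    # Sum, add Gaussian noise scaled to the clipping bound, then average.
    noise = noise_multiplier * clip_norm * jax.random.normal(key, w.shape)
    return (clipped.sum(axis=0) + noise) / xs.shape[0]

key_data, key_noise = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(key_data, (32, 5))
ys = jnp.ones(32)
w = jnp.zeros(5)
g = dp_gradient(w, xs, ys, key_noise)
```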

Awesome JAX / Tutorials and Blog Posts

Describes the state of JAX and the JAX ecosystem at DeepMind
Getting started with JAX (MLPs, CNNs & RNNs) by Robert Lange Neural network building blocks from scratch with the basic JAX operators
Learn JAX: From Linear Regression to Neural Networks by Rito Ghosh A gentle introduction to JAX, using it to implement linear regression, logistic regression, and neural network models and applying them to real-world problems
Tutorial: image classification with JAX and Flax Linen by 8bitmp3 24 over 1 year ago Learn how to create a simple convolutional network with the Linen API by Flax and train it to recognize handwritten digits
Plugging Into JAX by Nick Doiron Compares Flax, Haiku, and Objax on the Kaggle flower classification challenge
Meta-Learning in 50 Lines of JAX by Eric Jang Introduction to both JAX and Meta-Learning
Normalizing Flows in 100 Lines of JAX by Eric Jang Concise implementation of
Differentiable Path Tracing on the GPU/TPU by Eric Jang Tutorial on implementing path tracing
Ensemble networks by Mat Kelcey Ensemble nets are a method of representing an ensemble of models as one single logical model
Out of distribution (OOD) detection by Mat Kelcey Implements different methods for OOD detection
Understanding Autodiff with JAX by Srihari Radhakrishna Understand how autodiff works using JAX
From PyTorch to JAX: towards neural net frameworks that purify stateful code by Sabrina J. Mielke Showcases how to go from a PyTorch-like style of coding to a more functional style
Extending JAX with custom C++ and CUDA code by Dan Foreman-Mackey 378 3 months ago Tutorial demonstrating the infrastructure required to provide custom ops in JAX
Evolving Neural Networks in JAX by Robert Tjarko Lange Explores how JAX can power the next generation of scalable neuroevolution algorithms
Exploring hyperparameter meta-loss landscapes with JAX by Luke Metz Demonstrates how to use JAX to perform inner-loss optimization with SGD and Momentum, outer-loss optimization with gradients, and outer-loss optimization using evolutionary strategies
Deterministic ADVI in JAX by Martin Ingram Walkthrough of implementing automatic differentiation variational inference (ADVI) easily and cleanly with JAX
Evolved channel selection by Mat Kelcey Trains a classification model robust to different combinations of input channels at different resolutions, then uses a genetic algorithm to decide the best combination for a particular loss
Introduction to JAX by Kevin Murphy Colab that introduces various aspects of the language and applies them to simple ML problems
Writing an MCMC sampler in JAX by Jeremie Coullon Tutorial on the different ways to write an MCMC sampler in JAX along with speed benchmarks
How to add a progress bar to JAX scans and loops by Jeremie Coullon Tutorial on how to add a progress bar to compiled loops in JAX using the module
Get started with JAX by Aleksa Gordić 661 12 months ago A series of notebooks and videos going from zero JAX knowledge to building neural networks in Haiku
Writing a Training Loop in JAX + FLAX by Saurav Maheshkar and Soumik Rakshit A tutorial on writing a simple end-to-end training and evaluation pipeline in JAX, Flax and Optax
Implementing NeRF in JAX by Soumik Rakshit and Saurav Maheshkar A tutorial on 3D volumetric rendering of scenes represented by Neural Radiance Fields in JAX
Deep Learning tutorials with JAX+Flax by Phillip Lippe A series of notebooks explaining various deep learning concepts, from basics (e.g. intro to JAX/Flax, activation functions) to recent advances (e.g., Vision Transformers, SimCLR), with translations to PyTorch
Achieving 4000x Speedups with PureJaxRL A blog post on how JAX can massively speed up RL training through vectorisation
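
The MCMC and progress-bar tutorials above both revolve around loops written with `jax.lax.scan` so they compile as a whole. As a rough sketch of that pattern, here is a random-walk Metropolis sampler for a standard normal target; the target density, step size, chain length, and function names are illustrative assumptions, and the code is not taken from any of the listed tutorials:

```python
import jax
import jax.numpy as jnp

def log_prob(x):
    """Unnormalized log density of a standard normal target."""
    return -0.5 * jnp.sum(x ** 2)

def rw_metropolis(key, x0, num_samples, step_size=0.5):
    """Random-walk Metropolis; the whole chain runs inside one lax.scan."""
    def step(x, key):
        key_prop, key_accept = jax.random.split(key)
        proposal = x + step_size * jax.random.normal(key_prop, x.shape)
        # Symmetric proposal, so accept with probability min(1, p(proposal) / p(x)).
        log_ratio = log_prob(proposal) - log_prob(x)
        accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
        x_new = jnp.where(accept, proposal, x)
        return x_new, x_new

    keys = jax.random.split(key, num_samples)  # one PRNG key per iteration
    _, samples = jax.lax.scan(step, x0, keys)
    return samples

# num_samples is static because it determines the number of keys (an array shape).
sampler = jax.jit(rw_metropolis, static_argnums=2)
samples = sampler(jax.random.PRNGKey(0), jnp.zeros(2), 5000)
```

Running independent chains in parallel is then a matter of wrapping the sampler in `jax.vmap` over a batch of keys and starting points.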

Awesome JAX / Books

A hands-on guide to using JAX for deep learning and other mathematically-intensive applications

Awesome JAX / Community

JAX GitHub Discussions 30,499 7 days ago
Reddit
