awesome-jax
JAX library catalog
A curated list of resources and libraries for JAX machine learning research
JAX - A curated list of resources https://github.com/google/jax
2k stars
50 watching
132 forks
last commit: 4 months ago
Linked from 2 awesome lists
autograd, awesome, awesome-list, deep-learning, jax, machine-learning, neural-network, numpy, xla
Awesome JAX / Libraries | |||
Neural Network Libraries | |||
Awesome JAX / Libraries / Neural Network Libraries | |||
Flax | Centered on flexibility and clarity | ||
Haiku | 2,907 | 13 days ago | Focused on simplicity, created by the authors of Sonnet at DeepMind |
Objax | 771 | 10 months ago | Has an object-oriented design similar to PyTorch |
Elegy | A High Level API for Deep Learning in JAX. Supports Flax, Haiku, and Optax | ||
Trax | 8,096 | 2 months ago | "Batteries included" deep learning library focused on providing solutions for common workloads |
Jraph | 1,375 | 8 months ago | Lightweight graph neural network library |
Neural Tangents | 2,278 | 9 months ago | High-level API for specifying neural networks of both finite and infinite width |
HuggingFace | 135,022 | 6 days ago | Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax) |
Equinox | 2,118 | 21 days ago | Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX |
Scenic | 3,332 | about 1 month ago | A Jax Library for Computer Vision Research and Beyond |
Awesome JAX / Libraries | |||
Levanter | 516 | 3 days ago | Legible, Scalable, Reproducible Foundation Models with Named Tensors and JAX |
EasyLM | 2,409 | 3 months ago | LLMs made easy: Pre-training, finetuning, evaluating and serving LLMs in JAX/Flax |
NumPyro | 2,290 | 4 days ago | Probabilistic programming based on the Pyro library |
Chex | 788 | 8 days ago | Utilities to write and test reliable JAX code |
Optax | 1,697 | 9 days ago | Gradient processing and optimization library |
RLax | 1,263 | about 2 months ago | Library for implementing reinforcement learning agents |
JAX, M.D. | 1,185 | 21 days ago | Accelerated, differentiable molecular dynamics |
Coax | 167 | almost 2 years ago | Turn RL papers into code, the easy way |
Distrax | 536 | 2 months ago | Reimplementation of TensorFlow Probability, containing probability distributions and bijectors |
cvxpylayers | 1,819 | 9 days ago | Construct differentiable convex optimization layers |
TensorLy | 1,566 | 9 days ago | Tensor learning made simple |
NetKet | 548 | 6 days ago | Machine Learning toolbox for Quantum Physics |
Fortuna | 893 | 27 days ago | AWS library for Uncertainty Quantification in Deep Learning |
BlackJAX | 846 | 22 days ago | Library of samplers for JAX |
Awesome JAX / Libraries / New Libraries | |||
Neural Network Libraries | |||
Awesome JAX / Libraries / New Libraries / Neural Network Libraries | |||
FedJAX | Federated learning in JAX, built on Optax and Haiku | ||
Equivariant MLP | 257 | over 1 year ago | Construct equivariant neural network layers |
jax-resnet | 104 | over 2 years ago | Implementations and checkpoints for ResNet variants in Flax |
Parallax | 155 | over 4 years ago | Immutable Torch Modules for JAX |
Awesome JAX / Libraries / New Libraries | |||
jax-unirep | 104 | 3 months ago | Library implementing the UniRep model for protein machine learning applications |
jax-flows | 274 | over 1 year ago | Normalizing flows in JAX |
sklearn-jax-kernels | 42 | about 4 years ago | scikit-learn kernel matrices using JAX |
jax-cosmo | 178 | 4 months ago | Differentiable cosmology library |
efax | 55 | 4 days ago | Exponential Families in JAX |
mpi4jax | 445 | 17 days ago | Combine MPI operations with your Jax code on CPUs and GPUs |
imax | 37 | 8 months ago | Image augmentations and transformations |
FlaxVision | 44 | almost 4 years ago | Flax version of TorchVision |
Oryx | 4,269 | 6 days ago | Probabilistic programming language based on program transformations |
Optimal Transport Tools | 213 | almost 3 years ago | Toolbox that bundles utilities to solve optimal transport problems |
delta PV | 60 | almost 2 years ago | A photovoltaic simulator with automatic differentiation |
jaxlie | 233 | 2 months ago | Lie theory library for rigid body transformations and optimization |
BRAX | 2,337 | 7 days ago | Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments |
flaxmodels | 238 | over 1 year ago | Pretrained models for Jax/Flax |
CR.Sparse | 88 | about 1 year ago | XLA accelerated algorithms for sparse representations and compressive sensing |
exojax | 57 | about 1 month ago | Automatically differentiable spectrum modeling of exoplanets/brown dwarfs, compatible with JAX |
JAXopt | 933 | 2 months ago | Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX |
PIX | 389 | 13 days ago | PIX is an image processing library in JAX, for JAX |
bayex | 84 | 7 months ago | Bayesian Optimization powered by JAX |
JaxDF | 121 | 2 months ago | Framework for differentiable simulators with arbitrary discretizations |
tree-math | 188 | 6 months ago | Convert functions that operate on arrays into functions that operate on PyTrees |
jax-models | 151 | over 2 years ago | Implementations of research papers originally without code or code written with frameworks other than JAX |
PGMax | 64 | about 1 month ago | A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX |
EvoJAX | 843 | 5 months ago | Hardware-Accelerated Neuroevolution |
evosax | 506 | about 1 month ago | JAX-Based Evolution Strategies |
SymJAX | 120 | over 1 year ago | Symbolic CPU/GPU/TPU programming |
mcx | 325 | 8 months ago | Express & compile probabilistic programs for performant inference |
Einshape | 99 | 5 months ago | DSL-based reshaping library for JAX and other frameworks |
ALX | 34,295 | 6 days ago | Open-source library for distributed matrix factorization using Alternating Least Squares, more info in |
Diffrax | 1,442 | 3 days ago | Numerical differential equation solvers in JAX |
tinygp | 296 | 11 days ago | The tiniest of Gaussian process libraries in JAX |
gymnax | 650 | 5 months ago | Reinforcement Learning Environments with the well-known gym API |
Mctx | 2,356 | 4 months ago | Monte Carlo tree search algorithms in native JAX |
KFAC-JAX | 248 | 5 days ago | Second Order Optimization with Approximate Curvature for NNs |
TF2JAX | 105 | 17 days ago | Convert TensorFlow functions/graphs to JAX functions |
jwave | 143 | 2 months ago | A library for differentiable acoustic simulations |
GPJax | 461 | 20 days ago | Gaussian processes in JAX |
Jumanji | 622 | 6 days ago | A Suite of Industry-Driven Hardware-Accelerated RL Environments written in JAX |
Eqxvision | 102 | 4 months ago | Equinox version of Torchvision |
JAXFit | 51 | over 1 year ago | Accelerated curve fitting library for nonlinear least-squares problems (see ) |
econpizza | 78 | 9 days ago | Solve macroeconomic models with heterogeneous agents using JAX |
SPU | 242 | 3 days ago | A domain-specific compiler and runtime suite to run JAX code with MPC (Secure Multi-Party Computation) |
jax-tqdm | 96 | 28 days ago | Add a tqdm progress bar to JAX scans and loops |
safejax | 42 | 6 months ago | Serialize JAX, Flax, Haiku, or Objax model params with 🤗 |
Kernex | 66 | about 1 year ago | Differentiable stencil decorators in JAX |
MaxText | 1,529 | 4 days ago | A simple, performant and scalable Jax LLM written in pure Python/Jax and targeting Google Cloud TPUs |
Pax | 457 | 9 days ago | A Jax-based machine learning framework for training large scale models |
Praxis | 178 | 10 days ago | The layer library for Pax with a goal to be usable by other JAX-based ML projects |
purejaxrl | 722 | 2 months ago | Vectorisable, end-to-end RL algorithms in JAX |
Lorax | 132 | 9 months ago | Automatically apply LoRA to JAX models (Flax, Haiku, etc.) |
SCICO | 105 | 13 days ago | Scientific computational imaging in JAX |
Spyx | 101 | about 1 month ago | Spiking Neural Networks in JAX for machine learning on neuromorphic hardware |
BrainPy | 533 | 9 days ago | Brain Dynamics Programming in Python |
OTT-JAX | 525 | 3 days ago | Optimal transport tools in JAX |
QDax | 266 | about 1 month ago | Quality Diversity optimization in Jax |
JAX Toolbox | 245 | 3 days ago | Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries such as T5x, Paxml, and Transformer Engine |
Pgx | Vectorized board game environments for RL with an AlphaZero example | ||
EasyDeL | 206 | 5 days ago | EasyDeL 🔮 is an open-source library for faster, more optimized training, with options for training and serving models (Llama, MPT, Mixtral, Falcon, etc.) in JAX |
XLB | 229 | 17 days ago | A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning |
dynamiqs | 161 | 6 days ago | High-performance and differentiable simulations of quantum systems with JAX |
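The tree-math entry above describes turning functions that operate on arrays into functions that operate on PyTrees. The idea can be sketched with core JAX alone; this is not tree-math's API, and the helper `on_pytree` and the example function `normalize` are hypothetical names used only for illustration. It relies on `jax.flatten_util.ravel_pytree`, which flattens a PyTree into one 1-D array and returns a function that restores the original structure.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def normalize(v):
    # An ordinary array function: scale a 1-D vector to unit L2 norm.
    return v / jnp.linalg.norm(v)


def on_pytree(array_fn):
    # Hypothetical helper: lift a vector -> vector function to PyTrees.
    def wrapped(tree):
        flat, unravel = ravel_pytree(tree)  # flatten all leaves into one vector
        return unravel(array_fn(flat))      # apply, then restore the structure
    return wrapped


tree = {"a": jnp.array([3.0]), "b": jnp.array([0.0, 4.0])}
unit = on_pytree(normalize)(tree)
```

Here `unit` has the same dictionary structure as `tree`, with the leaves jointly rescaled so that the flattened vector has unit norm.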
Awesome JAX / Models and Projects / JAX | |||
Official implementation of | |||
kalman-jax | 95 | over 1 year ago | Approximate inference for Markov (i.e., temporal) Gaussian processes using iterated Kalman filtering and smoothing |
jaxns | 147 | 8 days ago | Nested sampling in JAX |
Amortized Bayesian Optimization | 34,295 | 6 days ago | Code related to |
Accurate Quantized Training | 34,295 | 6 days ago | Tools and libraries for running and analyzing neural network quantization experiments in JAX and Flax |
BNN-HMC | 34,295 | 6 days ago | Implementation for the paper |
JAX-DFT | 34,295 | 6 days ago | One-dimensional density functional theory (DFT) in JAX, with implementation of |
Robust Loss | 34,295 | 6 days ago | Reference code for the paper |
Symbolic Functionals | 34,295 | 6 days ago | Demonstration from |
TriMap | 34,295 | 6 days ago | Official JAX implementation of |
Awesome JAX / Models and Projects / Flax | |||
Performer | 34,295 | 6 days ago | Flax implementation of the Performer (linear transformer via FAVOR+) architecture |
JaxNeRF | 34,295 | 6 days ago | Implementation of with multi-device GPU/TPU support |
mip-NeRF | 900 | about 2 years ago | Official implementation of |
RegNeRF | 34,295 | 6 days ago | Official implementation of |
Big Transfer (BiT) | 1,513 | 4 months ago | Implementation of |
JAX RL | 630 | about 2 years ago | Implementations of reinforcement learning algorithms |
gMLP | Implementation of | ||
MLP Mixer | Minimal implementation of | ||
Distributed Shampoo | 34,295 | 6 days ago | Implementation of |
NesT | 193 | 4 months ago | Official implementation of |
XMC-GAN | 98 | 27 days ago | Official implementation of |
FNet | 34,295 | 6 days ago | Official implementation of |
GFSA | 34,295 | 6 days ago | Official implementation of |
IPA-GNN | 34,295 | 6 days ago | Official implementation of |
Flax Models | 34,295 | 6 days ago | Collection of models and methods implemented in Flax |
Protein LM | 34,295 | 6 days ago | Implements BERT and autoregressive models for proteins, as described in and |
Slot Attention | 34,295 | 6 days ago | Reference implementation for |
Vision Transformer | 10,450 | 6 months ago | Official implementation of |
FID computation | 24 | 4 months ago | Port of to Flax |
ARDM | 34,295 | 6 days ago | Official implementation of |
D3PM | 34,295 | 6 days ago | Official implementation of |
Gumbel-max Causal Mechanisms | 34,295 | 6 days ago | Code for , with extra code in |
Latent Programmer | 34,295 | 6 days ago | Code for the ICML 2021 paper |
SNeRG | 34,295 | 6 days ago | Official implementation of |
Spin-weighted Spherical CNNs | 34,295 | 6 days ago | Adaptation of |
VDVAE | 34,295 | 6 days ago | Adaptation of , original code at |
MUSIQ | 34,295 | 6 days ago | Checkpoints and model inference code for the ICCV 2021 paper |
AQuaDem | 34,295 | 6 days ago | Official implementation of |
Combiner | 34,295 | 6 days ago | Official implementation of |
Dreamfields | 34,295 | 6 days ago | Official implementation of the ICLR 2022 paper |
GIFT | 34,295 | 6 days ago | Official implementation of |
Light Field Neural Rendering | 34,295 | 6 days ago | Official implementation of |
Sharpened Cosine Similarity in JAX by Raphael Pisoni | A JAX/Flax implementation of the Sharpened Cosine Similarity layer | ||
GNNs for Solving Combinatorial Optimization Problems | 42 | almost 2 years ago | A JAX + Flax implementation of |
Awesome JAX / Models and Projects / Haiku | |||
AlphaFold | 12,859 | 5 months ago | Implementation of the inference pipeline of AlphaFold v2.0, presented in |
Adversarial Robustness | 13,250 | 26 days ago | Reference code for and |
Bootstrap Your Own Latent | 13,250 | 26 days ago | Implementation for the paper |
Gated Linear Networks | 13,250 | 26 days ago | GLNs are a family of backpropagation-free neural networks |
Glassy Dynamics | 13,250 | 26 days ago | Open source implementation of the paper |
MMV | 13,250 | 26 days ago | Code for the models in |
Normalizer-Free Networks | 13,250 | 26 days ago | Official Haiku implementation of |
NuX | 82 | 12 months ago | Normalizing flows with JAX |
OGB-LSC | 13,250 | 26 days ago | This repository contains DeepMind's entry to the (quantum chemistry) and (academic graph) tracks of the (OGB-LSC) |
Persistent Evolution Strategies | 34,295 | 6 days ago | Code used for the paper |
Two Player Auction Learning | 0 | 12 months ago | JAX implementation of the paper |
WikiGraphs | 13,250 | 26 days ago | Baseline code to reproduce results in |
Awesome JAX / Models and Projects / Trax | |||
Reformer | 8,096 | 2 months ago | Implementation of the Reformer (efficient transformer) architecture |
Awesome JAX / Models and Projects / NumPyro | |||
lqg | 23 | 10 months ago | Official implementation of Bayesian inverse optimal control for linear-quadratic Gaussian problems from the paper |
Awesome JAX / Videos | |||
JAX, its use at DeepMind, and discussion between engineers, scientists, and JAX core team | |||
Introduction to JAX | Simple neural network from scratch in JAX | ||
JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas | JAX's core design, how it's powering new research, and how you can start using it | ||
Bayesian Programming with JAX + NumPyro — Andy Kitchen | Introduction to Bayesian modelling using NumPyro | ||
JAX: Accelerated machine-learning research via composable function transformations in Python | NeurIPS 2019 | Skye Wanderman-Milne | JAX intro presentation in workshop | ||
JAX on Cloud TPUs | NeurIPS 2020 | Skye Wanderman-Milne and James Bradbury | Presentation of TPU host access with demo | ||
Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond | NeurIPS 2020 | Tutorial created by Zico Kolter, David Duvenaud, and Matt Johnson with Colab notebooks available in | ||
Solving y=mx+b with Jax on a TPU Pod slice - Mat Kelcey | A four part YouTube tutorial series with Colab notebooks that starts with Jax fundamentals and moves up to training with a data parallel approach on a v3-32 TPU Pod slice | ||
JAX, Flax & Transformers 🤗 | 135,022 | 6 days ago | 3 days of talks around JAX / Flax, Transformers, large-scale language modeling and other great topics |
Awesome JAX / Papers | |||
Compiling machine learning programs via high-level tracing. Roy Frostig, Matthew James Johnson, Chris Leary. MLSys 2018. | White paper describing an early version of JAX, detailing how computation is traced and compiled | ||
JAX, M.D.: A Framework for Differentiable Physics. Samuel S. Schoenholz, Ekin D. Cubuk. NeurIPS 2020. | Introduces JAX, M.D., a differentiable physics library which includes simulation environments, interaction potentials, neural networks, and more | ||
Enabling Fast Differentially Private SGD via Just-in-Time Compilation and Vectorization. Pranav Subramani, Nicholas Vadivelu, Gautam Kamath. arXiv 2020. | Uses JAX's JIT and VMAP to achieve faster differentially private SGD than existing libraries | ||
XLB: A Differentiable Massively Parallel Lattice Boltzmann Library in Python. Mohammadmehdi Ataei, Hesam Salehipour. arXiv 2023. | White paper describing the XLB library: benchmarks, validations, and more details about the library | ||
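The differentially private SGD entry above credits JAX's JIT and VMAP for its speed. A minimal sketch of that combination follows: `jax.vmap` over `jax.grad` produces one gradient per example, each gradient is clipped to a maximum L2 norm, and the whole step is compiled with `jax.jit`. The linear model, the `clip_norm` default, and the function names are assumptions made for illustration, not the paper's code, and the noise addition that DP-SGD also requires is omitted.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Squared error of a hypothetical linear model on a single example.
    return (jnp.dot(x, w) - y) ** 2


@jax.jit
def clipped_mean_grad(w, xs, ys, clip_norm=1.0):
    # vmap over the batch axis of xs and ys: one gradient per example.
    per_example = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(w, xs, ys)
    norms = jnp.linalg.norm(per_example, axis=1, keepdims=True)
    # Rescale each gradient so its L2 norm is at most clip_norm.
    scale = jnp.minimum(1.0, clip_norm / (norms + 1e-12))
    return jnp.mean(per_example * scale, axis=0)


w = jnp.zeros(3)
xs = jnp.ones((4, 3))
ys = jnp.arange(4.0)
g = clipped_mean_grad(w, xs, ys)
```

Because every per-example gradient is clipped before averaging, the norm of `g` cannot exceed `clip_norm`, which is the property the privacy analysis depends on.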
Awesome JAX / Tutorials and Blog Posts | |||
Describes the state of JAX and the JAX ecosystem at DeepMind | |||
Getting started with JAX (MLPs, CNNs & RNNs) by Robert Lange | Neural network building blocks from scratch with the basic JAX operators | ||
Learn JAX: From Linear Regression to Neural Networks by Rito Ghosh | A gentle introduction to JAX that implements linear regression, logistic regression, and neural network models and applies them to real-world problems | ||
Tutorial: image classification with JAX and Flax Linen by 8bitmp3 | 24 | over 1 year ago | Learn how to create a simple convolutional network with the Linen API by Flax and train it to recognize handwritten digits |
Plugging Into JAX by Nick Doiron | Compares Flax, Haiku, and Objax on the Kaggle flower classification challenge | ||
Meta-Learning in 50 Lines of JAX by Eric Jang | Introduction to both JAX and Meta-Learning | ||
Normalizing Flows in 100 Lines of JAX by Eric Jang | Concise implementation of | ||
Differentiable Path Tracing on the GPU/TPU by Eric Jang | Tutorial on implementing path tracing | ||
Ensemble networks by Mat Kelcey | Ensemble nets are a method of representing an ensemble of models as one single logical model | ||
Out of distribution (OOD) detection by Mat Kelcey | Implements different methods for OOD detection | ||
Understanding Autodiff with JAX by Srihari Radhakrishna | Understand how autodiff works using JAX | ||
From PyTorch to JAX: towards neural net frameworks that purify stateful code by Sabrina J. Mielke | Showcases how to go from a PyTorch-like style of coding to a more functional style of coding | ||
Extending JAX with custom C++ and CUDA code by Dan Foreman-Mackey | 378 | 3 months ago | Tutorial demonstrating the infrastructure required to provide custom ops in JAX |
Evolving Neural Networks in JAX by Robert Tjarko Lange | Explores how JAX can power the next generation of scalable neuroevolution algorithms | ||
Exploring hyperparameter meta-loss landscapes with JAX by Luke Metz | Demonstrates how to use JAX to perform inner-loss optimization with SGD and Momentum, outer-loss optimization with gradients, and outer-loss optimization using evolutionary strategies | ||
Deterministic ADVI in JAX by Martin Ingram | Walkthrough of implementing automatic differentiation variational inference (ADVI) easily and cleanly with JAX | ||
Evolved channel selection by Mat Kelcey | Trains a classification model robust to different combinations of input channels at different resolutions, then uses a genetic algorithm to decide the best combination for a particular loss | ||
Introduction to JAX by Kevin Murphy | Colab that introduces various aspects of the language and applies them to simple ML problems | ||
Writing an MCMC sampler in JAX by Jeremie Coullon | Tutorial on the different ways to write an MCMC sampler in JAX along with speed benchmarks | ||
How to add a progress bar to JAX scans and loops by Jeremie Coullon | Tutorial on how to add a progress bar to compiled loops in JAX using the module | ||
Get started with JAX by Aleksa Gordić | 661 | 12 months ago | A series of notebooks and videos going from zero JAX knowledge to building neural networks in Haiku |
Writing a Training Loop in JAX + FLAX by Saurav Maheshkar and Soumik Rakshit | A tutorial on writing a simple end-to-end training and evaluation pipeline in JAX, Flax and Optax | ||
Implementing NeRF in JAX by Soumik Rakshit and Saurav Maheshkar | A tutorial on 3D volumetric rendering of scenes represented by Neural Radiance Fields in JAX | ||
Deep Learning tutorials with JAX+Flax by Phillip Lippe | A series of notebooks explaining various deep learning concepts, from basics (e.g., intro to JAX/Flax, activation functions) to recent advances (e.g., Vision Transformers, SimCLR), with translations to PyTorch | ||
Achieving 4000x Speedups with PureJaxRL | A blog post on how JAX can massively speed up RL training through vectorisation | ||
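Several of the tutorials above start from linear regression and stress a functional coding style in which parameters are passed in and returned rather than mutated. A minimal sketch of that pattern, using only core JAX, might look like the following; the synthetic data, learning rate, and step count are assumptions chosen for illustration and are not taken from any of the linked tutorials.

```python
import jax
import jax.numpy as jnp


def predict(params, x):
    m, b = params
    return m * x + b


def mse(params, x, y):
    # Mean squared error of the line m * x + b against targets y.
    return jnp.mean((predict(params, x) - y) ** 2)


@jax.jit
def step(params, x, y, lr=0.1):
    grads = jax.grad(mse)(params, x, y)  # gradient w.r.t. the params tuple
    # Return new parameters instead of updating state in place.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


# Synthetic data from the line y = 2x + 1 (an assumption for this sketch).
x = jnp.linspace(-1.0, 1.0, 20)
y = 2.0 * x + 1.0

params = (jnp.array(0.0), jnp.array(0.0))
for _ in range(500):
    params = step(params, x, y)
```

After the loop, `params` should be close to the slope and intercept used to generate the data. The same `step` function works unchanged if `params` is swapped for a larger PyTree, because `jax.grad` and `tree_map` both follow the structure of their input.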
Awesome JAX / Books | |||
A hands-on guide to using JAX for deep learning and other mathematically-intensive applications | |||
Awesome JAX / Community | |||
JAX GitHub Discussions | 30,499 | 7 days ago | |