awesome-jax

JAX library catalog

A curated list of resources and libraries for JAX machine learning research

JAX - A curated list of resources https://github.com/google/jax


2k stars
51 watching
134 forks
last commit: 7 months ago
Linked from 2 awesome lists

autograd, awesome, awesome-list, deep-learning, jax, machine-learning, neural-network, numpy, xla

Awesome JAX / Libraries

Awesome JAX / Libraries / Neural Network Libraries

Flax Centered on flexibility and clarity
Haiku 2,921 about 2 months ago Focused on simplicity, created by the authors of Sonnet at DeepMind
Objax 772 about 1 year ago Has an object-oriented design similar to PyTorch
Elegy A high-level API for deep learning in JAX. Supports Flax, Haiku, and Optax
Trax 8,114 about 2 months ago "Batteries included" deep learning library focused on providing solutions for common workloads
Jraph 1,380 11 months ago Lightweight graph neural network library
Neural Tangents 2,291 11 months ago High-level API for specifying neural networks of both finite and infinite width
HuggingFace 136,357 about 1 month ago Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax)
Equinox 2,157 about 2 months ago Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX
Scenic 3,363 about 2 months ago A Jax Library for Computer Vision Research and Beyond
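Most of the libraries above build on the same functional pattern: parameters live in a PyTree (nested lists and dicts of arrays), and `jax.grad` and `jax.jit` transform pure functions of those parameters. The sketch below shows that pattern in plain JAX with no library at all; `init_mlp`, `mlp`, and `sgd_step` are hypothetical names chosen for illustration, not the API of Haiku, Flax, Equinox, or any other entry.

```python
import jax
import jax.numpy as jnp

# Illustrative sketch: hypothetical helpers, not any listed library's API.

def init_mlp(key, sizes):
    """Return a list of {'w', 'b'} dicts, i.e. a PyTree of parameters."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, (n_in, n_out) in zip(keys, zip(sizes[:-1], sizes[1:])):
        w = jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append({"w": w, "b": jnp.zeros(n_out)})
    return params

def mlp(params, x):
    # ReLU hidden layers, linear output layer
    for layer in params[:-1]:
        x = jax.nn.relu(x @ layer["w"] + layer["b"])
    last = params[-1]
    return x @ last["w"] + last["b"]

def loss_fn(params, x, y):
    return jnp.mean((mlp(params, x) - y) ** 2)

@jax.jit
def sgd_step(params, x, y, lr=0.05):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # grads has exactly the same PyTree structure as params
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Usage: fit y = x^2 on [-1, 1]
params = init_mlp(jax.random.PRNGKey(0), [1, 16, 1])
x = jnp.linspace(-1.0, 1.0, 32).reshape(-1, 1)
y = x ** 2
losses = []
for _ in range(200):
    params, loss = sgd_step(params, x, y)
    losses.append(float(loss))
```

The libraries differ mainly in how they build and name this parameter PyTree (modules, transforms, or callable PyTrees); the gradient step itself stays a `tree_map` over matching structures.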

Awesome JAX / Libraries

Levanter 527 about 1 month ago Legible, Scalable, Reproducible Foundation Models with Named Tensors and JAX
EasyLM 2,428 6 months ago LLMs made easy: Pre-training, finetuning, evaluating and serving LLMs in JAX/Flax
NumPyro 2,341 about 2 months ago Probabilistic programming based on the Pyro library
Chex 797 about 2 months ago Utilities to write and test reliable JAX code
Optax 1,730 about 1 month ago Gradient processing and optimization library
RLax 1,272 4 months ago Library for implementing reinforcement learning agents
JAX, M.D. 1,204 2 months ago Accelerated, differentiable molecular dynamics
Coax 168 almost 2 years ago Turn RL papers into code, the easy way
Distrax 538 about 2 months ago Reimplementation of TensorFlow Probability, containing probability distributions and bijectors
cvxpylayers 1,843 about 2 months ago Construct differentiable convex optimization layers
TensorLy 1,576 about 1 month ago Tensor learning made simple
NetKet 554 about 2 months ago Machine Learning toolbox for Quantum Physics
Fortuna 892 2 months ago AWS library for Uncertainty Quantification in Deep Learning
BlackJAX 858 3 months ago Library of samplers for JAX
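Optax, listed above, models an optimizer as a `GradientTransformation`: a pair of `init` and `update` functions that can be chained so each stage transforms the gradients produced by the previous one. The sketch below hand-rolls that idea in plain JAX to show how it works; `Transform`, `clip_by_global_norm`, `momentum_sgd`, and `chain` here are hypothetical helpers written for illustration, not imported from Optax, and the real library's versions differ in detail (for example, in their `update` signatures and state types).

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp

# Illustrative sketch of composable gradient transformations (not Optax's code).

class Transform(NamedTuple):
    init: Callable    # params -> state
    update: Callable  # (grads, state) -> (updates, new_state)

def clip_by_global_norm(max_norm):
    def init(params):
        return ()  # stateless
    def update(grads, state):
        leaves = jax.tree_util.tree_leaves(grads)
        norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
        scale = jnp.minimum(1.0, max_norm / (norm + 1e-12))
        return jax.tree_util.tree_map(lambda g: g * scale, grads), state
    return Transform(init, update)

def momentum_sgd(lr, beta=0.9):
    def init(params):
        return jax.tree_util.tree_map(jnp.zeros_like, params)  # velocity
    def update(grads, velocity):
        velocity = jax.tree_util.tree_map(lambda v, g: beta * v + g, velocity, grads)
        updates = jax.tree_util.tree_map(lambda v: -lr * v, velocity)
        return updates, velocity
    return Transform(init, update)

def chain(*transforms):
    def init(params):
        return tuple(t.init(params) for t in transforms)
    def update(grads, states):
        new_states = []
        for t, s in zip(transforms, states):
            grads, s = t.update(grads, s)  # output of one stage feeds the next
            new_states.append(s)
        return grads, tuple(new_states)
    return Transform(init, update)

# Usage: minimize ||w||^2 with clipped momentum SGD
params = {"w": jnp.array([3.0, -4.0])}
opt = chain(clip_by_global_norm(1.0), momentum_sgd(lr=0.1))
state = opt.init(params)
objective = lambda p: jnp.sum(p["w"] ** 2)
for _ in range(100):
    grads = jax.grad(objective)(params)
    updates, state = opt.update(grads, state)
    params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
```

Keeping each stage's state separate is what lets such transformations be mixed and matched without any stage knowing about the others.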

Awesome JAX / Libraries / New Libraries

Awesome JAX / Libraries / New Libraries / Neural Network Libraries

FedJAX Federated learning in JAX, built on Optax and Haiku
Equivariant MLP 257 over 1 year ago Construct equivariant neural network layers
jax-resnet 105 over 2 years ago Implementations and checkpoints for ResNet variants in Flax
Parallax 155 over 4 years ago Immutable Torch Modules for JAX
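Federated learning, the subject of the first entry in this list, trains one model across many clients without pooling their data. Its canonical algorithm, federated averaging (FedAvg), runs a few local gradient steps on each client and then averages the resulting models weighted by client dataset size. The sketch below simulates this on one device, assuming a linear least-squares model and clients stacked along axis 0; it is a generic FedAvg illustration, not FedJAX's API, and `client_update` and `fedavg_round` are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Illustrative FedAvg sketch on simulated clients (hypothetical helpers).

def local_loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

def client_update(w, x, y, lr=0.1, steps=5):
    # A few steps of local gradient descent on one client's data
    def body(w, _):
        return w - lr * jax.grad(local_loss)(w, x, y), None
    w, _ = jax.lax.scan(body, w, None, length=steps)
    return w

@jax.jit
def fedavg_round(w, xs, ys, counts):
    # vmap runs client_update for every client at once (clients on axis 0)
    client_ws = jax.vmap(client_update, in_axes=(None, 0, 0))(w, xs, ys)
    weights = counts / jnp.sum(counts)
    # Server step: average client models weighted by local dataset size
    return jnp.tensordot(weights, client_ws, axes=1)

# Usage: 4 simulated clients, 20 examples each, shared true model
true_w = jnp.array([2.0, -1.0])
xs = jax.random.normal(jax.random.PRNGKey(0), (4, 20, 2))
ys = xs @ true_w
counts = jnp.full((4,), 20.0)
w = jnp.zeros(2)
for _ in range(20):
    w = fedavg_round(w, xs, ys, counts)
```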

Awesome JAX / Libraries / New Libraries

jax-unirep 104 5 months ago Library implementing the UniRep model for protein machine learning applications
jax-flows 275 over 1 year ago Normalizing flows in JAX
sklearn-jax-kernels 42 over 4 years ago scikit-learn kernel matrices using JAX
jax-cosmo 182 6 months ago Differentiable cosmology library
efax 58 about 1 month ago Exponential Families in JAX
mpi4jax 453 about 2 months ago Combine MPI operations with your Jax code on CPUs and GPUs
imax 37 10 months ago Image augmentations and transformations
FlaxVision 44 about 4 years ago Flax version of TorchVision
Oryx 4,274 about 2 months ago Probabilistic programming language based on program transformations
Optimal Transport Tools 213 almost 3 years ago Toolbox that bundles utilities to solve optimal transport problems
delta PV 60 almost 2 years ago A photovoltaic simulator with automatic differentiation
jaxlie 236 2 months ago Lie theory library for rigid body transformations and optimization
BRAX 2,397 about 2 months ago Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments
flaxmodels 240 over 1 year ago Pretrained models for Jax/Flax
CR.Sparse 88 over 1 year ago XLA accelerated algorithms for sparse representations and compressive sensing
exojax 57 about 2 months ago Automatic differentiable spectrum modeling of exoplanets/brown dwarfs compatible with JAX
JAXopt 941 4 months ago Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX
PIX 395 about 2 months ago PIX is an image processing library in JAX, for JAX
bayex 86 2 months ago Bayesian Optimization powered by JAX
JaxDF 124 4 months ago Framework for differentiable simulators with arbitrary discretizations
tree-math 194 about 2 months ago Convert functions that operate on arrays into functions that operate on PyTrees
jax-models 151 over 2 years ago Implementations of research papers originally without code or code written with frameworks other than JAX
PGMax 64 4 months ago A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX
EvoJAX 856 7 months ago Hardware-Accelerated Neuroevolution
evosax 514 3 months ago JAX-Based Evolution Strategies
SymJAX 120 over 1 year ago Symbolic CPU/GPU/TPU programming
mcx 328 11 months ago Express & compile probabilistic programs for performant inference
Einshape 100 7 months ago DSL-based reshaping library for JAX and other frameworks
ALX 34,478 about 2 months ago Open-source library for distributed matrix factorization using Alternating Least Squares, more info in
Diffrax 1,480 about 2 months ago Numerical differential equation solvers in JAX
tinygp 297 about 2 months ago The tiniest of Gaussian process libraries in JAX
gymnax 669 7 months ago Reinforcement Learning Environments with the well-known gym API
Mctx 2,377 about 2 months ago Monte Carlo tree search algorithms in native JAX
KFAC-JAX 252 about 2 months ago Second Order Optimization with Approximate Curvature for NNs
TF2JAX 109 about 1 month ago Convert TensorFlow functions/graphs to JAX functions
jwave 145 4 months ago A library for differentiable acoustic simulations
GPJax 467 2 months ago Gaussian processes in JAX
Jumanji 657 about 2 months ago A Suite of Industry-Driven Hardware-Accelerated RL Environments written in JAX
Eqxvision 102 6 months ago Equinox version of Torchvision
JAXFit 53 over 1 year ago Accelerated curve fitting library for nonlinear least-squares problems
econpizza 79 about 2 months ago Solve macroeconomic models with heterogeneous agents using JAX
SPU 246 about 1 month ago A domain-specific compiler and runtime suite to run JAX code with MPC (Secure Multi-Party Computation)
jax-tqdm 100 3 months ago Add a tqdm progress bar to JAX scans and loops
safejax 42 8 months ago Serialize JAX, Flax, Haiku, or Objax model params with 🤗 safetensors
Kernex 67 over 1 year ago Differentiable stencil decorators in JAX
MaxText 1,557 about 1 month ago A simple, performant and scalable Jax LLM written in pure Python/Jax and targeting Google Cloud TPUs
Pax 461 about 2 months ago A Jax-based machine learning framework for training large scale models
Praxis 178 about 1 month ago The layer library for Pax with a goal to be usable by other JAX-based ML projects
purejaxrl 755 5 months ago Vectorisable, end-to-end RL algorithms in JAX
Lorax 134 11 months ago Automatically apply LoRA to JAX models (Flax, Haiku, etc.)
SCICO 107 about 2 months ago Scientific computational imaging in JAX
Spyx 104 4 months ago Spiking Neural Networks in JAX for machine learning on neuromorphic hardware
BrainPy 541 about 1 month ago Brain Dynamics Programming in Python
OTT-JAX 550 about 2 months ago Optimal transport tools in JAX
QDax 270 4 months ago Quality Diversity optimization in Jax
JAX Toolbox 268 about 1 month ago Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries such as T5x, Paxml, and Transformer Engine
Pgx Vectorized board game environments for RL with an AlphaZero example
EasyDeL 212 about 1 month ago EasyDeL 🔮 is an open-source library that aims to make training faster and more optimized, with options for training and serving models (Llama, MPT, Mixtral, Falcon, etc.) in JAX
XLB 241 about 2 months ago A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning
dynamiqs 184 about 1 month ago High-performance and differentiable simulations of quantum systems with JAX
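Several entries above (Diffrax, for example) solve differential equations while remaining differentiable. A minimal fixed-step, fourth-order Runge-Kutta integrator written with `jax.lax.scan` shows why JAX suits this: the solver is an ordinary traced function, so `jax.grad` differentiates straight through it. This is a sketch, not Diffrax's API, which offers adaptive step sizes and many more solvers; `rk4_solve` and the decay model are hypothetical illustrations.

```python
import jax
import jax.numpy as jnp

# Illustrative fixed-step RK4 solver (hypothetical helper, not Diffrax).

def rk4_solve(f, y0, t0, t1, num_steps):
    """Integrate dy/dt = f(t, y) from t0 to t1 with classical RK4."""
    dt = (t1 - t0) / num_steps

    def step(carry, _):
        t, y = carry
        k1 = f(t, y)
        k2 = f(t + dt / 2, y + dt / 2 * k1)
        k3 = f(t + dt / 2, y + dt / 2 * k2)
        k4 = f(t + dt, y + dt * k3)
        y = y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return (t + dt, y), y

    t_init = jnp.asarray(t0, dtype=y0.dtype)
    (_, y_final), ys = jax.lax.scan(step, (t_init, y0), None, length=num_steps)
    return y_final, ys

# Usage: exponential decay dy/dt = -k y, y(0) = 1, so y(1) = exp(-k)
def final_value(k):
    y0 = jnp.array(1.0, dtype=jnp.float32)
    y_final, _ = rk4_solve(lambda t, y: -k * y, y0, 0.0, 1.0, 100)
    return y_final

y1 = final_value(2.0)                  # close to exp(-2), about 0.1353
dy1_dk = jax.grad(final_value)(2.0)    # close to d/dk exp(-k) = -exp(-2)
```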

Awesome JAX / Models and Projects / JAX

Official implementation of
kalman-jax 97 over 1 year ago Approximate inference for Markov (i.e., temporal) Gaussian processes using iterated Kalman filtering and smoothing
jaxns 156 about 2 months ago Nested sampling in JAX
Amortized Bayesian Optimization 34,478 about 2 months ago Code related to
Accurate Quantized Training 34,478 about 2 months ago Tools and libraries for running and analyzing neural network quantization experiments in JAX and Flax
BNN-HMC 34,478 about 2 months ago Implementation for the paper
JAX-DFT 34,478 about 2 months ago One-dimensional density functional theory (DFT) in JAX, with implementation of
Robust Loss 34,478 about 2 months ago Reference code for the paper
Symbolic Functionals 34,478 about 2 months ago Demonstration from
TriMap 34,478 about 2 months ago Official JAX implementation of
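kalman-jax, listed above, builds on Kalman filtering and smoothing. The basic predict/update recursion maps naturally onto `jax.lax.scan`, with the filter's mean and variance as the carry. A minimal sketch for a scalar random-walk model, assuming known process variance `q` and observation variance `r`; this is the textbook filter, not kalman-jax's iterated approximate-inference method, and `kalman_filter` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

# Illustrative scalar Kalman filter (hypothetical helper, not kalman-jax's API).

def kalman_filter(ys, q, r, m0=0.0, p0=1.0):
    """Filter x_t = x_{t-1} + N(0, q), observed as y_t = x_t + N(0, r)."""
    def step(carry, y):
        m, p = carry
        # Predict: the random walk keeps the mean and adds process variance
        m_pred, p_pred = m, p + q
        # Update: blend prediction and observation via the Kalman gain
        gain = p_pred / (p_pred + r)
        m_new = m_pred + gain * (y - m_pred)
        p_new = (1.0 - gain) * p_pred
        return (m_new, p_new), (m_new, p_new)

    init = (jnp.asarray(m0, dtype=ys.dtype), jnp.asarray(p0, dtype=ys.dtype))
    _, (means, variances) = jax.lax.scan(step, init, ys)
    return means, variances

# Usage: simulate a random walk (q = 0.01) with noisy observations (r = 0.25)
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
xs = jnp.cumsum(0.1 * jax.random.normal(key_x, (200,)))
ys = xs + 0.5 * jax.random.normal(key_y, (200,))
means, variances = kalman_filter(ys, q=0.01, r=0.25)
```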

Awesome JAX / Models and Projects / Flax

Performer 34,478 about 2 months ago Flax implementation of the Performer (linear transformer via FAVOR+) architecture
JaxNeRF 34,478 about 2 months ago Implementation of with multi-device GPU/TPU support
mip-NeRF 905 over 2 years ago Official implementation of
RegNeRF 34,478 about 2 months ago Official implementation of
Big Transfer (BiT) 1,516 6 months ago Implementation of
JAX RL 640 over 2 years ago Implementations of reinforcement learning algorithms
gMLP Implementation of
MLP Mixer Minimal implementation of
Distributed Shampoo 34,478 about 2 months ago Implementation of
NesT 195 6 months ago Official implementation of
XMC-GAN 98 3 months ago Official implementation of
FNet 34,478 about 2 months ago Official implementation of
GFSA 34,478 about 2 months ago Official implementation of
IPA-GNN 34,478 about 2 months ago Official implementation of
Flax Models 34,478 about 2 months ago Collection of models and methods implemented in Flax
Protein LM 34,478 about 2 months ago Implements BERT and autoregressive models for proteins, as described in and
Slot Attention 34,478 about 2 months ago Reference implementation for
Vision Transformer 10,620 about 2 months ago Official implementation of
FID computation 24 7 months ago Port of to Flax
ARDM 34,478 about 2 months ago Official implementation of
D3PM 34,478 about 2 months ago Official implementation of
Gumbel-max Causal Mechanisms 34,478 about 2 months ago Code for , with extra code in
Latent Programmer 34,478 about 2 months ago Code for the ICML 2021 paper
SNeRG 34,478 about 2 months ago Official implementation of
Spin-weighted Spherical CNNs 34,478 about 2 months ago Adaptation of
VDVAE 34,478 about 2 months ago Adaptation of , original code at
MUSIQ 34,478 about 2 months ago Checkpoints and model inference code for the ICCV 2021 paper
AQuaDem 34,478 about 2 months ago Official implementation of
Combiner 34,478 about 2 months ago Official implementation of
Dreamfields 34,478 about 2 months ago Official implementation of the ICLR 2022 paper
GIFT 34,478 about 2 months ago Official implementation of
Light Field Neural Rendering 34,478 about 2 months ago Official implementation of
Sharpened Cosine Similarity in JAX by Raphael Pisoni A JAX/Flax implementation of the Sharpened Cosine Similarity layer
GNNs for Solving Combinatorial Optimization Problems 43 almost 2 years ago A JAX + Flax implementation of
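The Sharpened Cosine Similarity entry above replaces the plain dot product of a standard layer with a sharpened cosine similarity. One commonly cited formulation is scs(s, k) = sign(s·k) · (|s·k| / ((‖s‖ + q)(‖k‖ + q)))^p, with a sharpening exponent p and a small stabilizer q. The sketch below is a dense (non-convolutional) version of that formula in plain JAX; the defaults for `p` and `q` are illustrative assumptions, and the linked implementation may differ (for example, by learning p and q per unit).

```python
import jax
import jax.numpy as jnp

# Illustrative dense sharpened-cosine-similarity layer (assumed formulation).

def sharpened_cosine_similarity(x, w, p=2.0, q=1e-3):
    """x: (batch, d_in), w: (d_in, d_out) -> (batch, d_out)."""
    dot = x @ w
    x_norm = jnp.linalg.norm(x, axis=-1, keepdims=True)  # (batch, 1)
    w_norm = jnp.linalg.norm(w, axis=0, keepdims=True)   # (1, d_out)
    cos = dot / ((x_norm + q) * (w_norm + q))
    # Raise the magnitude to the power p but keep the sign
    return jnp.sign(cos) * jnp.abs(cos) ** p

# Usage: two inputs, two output units (columns of w)
x = jnp.array([[1.0, 0.0], [1.0, 1.0]])
w = jnp.array([[1.0, 0.0], [0.0, -1.0]])
out = sharpened_cosine_similarity(x, w, p=2.0, q=0.0)
```

With p = 2, a 45-degree match (cosine about 0.707) is pushed down to 0.5 while a perfect match stays at 1, which is the "sharpening" effect.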

Awesome JAX / Models and Projects / Haiku

AlphaFold 12,997 2 months ago Implementation of the inference pipeline of AlphaFold v2.0, presented in
Adversarial Robustness 13,329 2 months ago Reference code for and
Bootstrap Your Own Latent 13,329 2 months ago Implementation for the paper
Gated Linear Networks 13,329 2 months ago GLNs are a family of backpropagation-free neural networks
Glassy Dynamics 13,329 2 months ago Open source implementation of the paper
MMV 13,329 2 months ago Code for the models in
Normalizer-Free Networks 13,329 2 months ago Official Haiku implementation of
NuX 82 about 1 year ago Normalizing flows with JAX
OGB-LSC 13,329 2 months ago This repository contains DeepMind's entry to the quantum chemistry and academic graph tracks of the OGB Large-Scale Challenge (OGB-LSC)
Persistent Evolution Strategies 34,478 about 2 months ago Code used for the paper
Two Player Auction Learning 0 about 1 year ago JAX implementation of the paper
WikiGraphs 13,329 2 months ago Baseline code to reproduce results in
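In Bootstrap Your Own Latent (BYOL), listed above, the target network is not trained by gradients; its parameters track an exponential moving average of the online network's parameters, ξ ← τξ + (1 − τ)θ. Because Haiku parameters are PyTrees, that update is a single `jax.tree_util.tree_map`. A sketch assuming both networks share the same parameter structure; `ema_update` and the default `tau` are hypothetical illustrations, not code from the linked repository.

```python
import jax
import jax.numpy as jnp

# Illustrative EMA target update (hypothetical helper).

def ema_update(target_params, online_params, tau=0.99):
    """target <- tau * target + (1 - tau) * online, leaf by leaf."""
    return jax.tree_util.tree_map(
        lambda t, o: tau * t + (1.0 - tau) * o, target_params, online_params
    )

# Usage: one update with tau = 0.9 moves the target 10% toward the online params
online = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
target = {"w": jnp.zeros((2, 2)), "b": jnp.zeros(2)}
target = ema_update(target, online, tau=0.9)
```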

Awesome JAX / Models and Projects / Trax

Reformer 8,114 about 2 months ago Implementation of the Reformer (efficient transformer) architecture
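One ingredient of the Reformer architecture is the reversible residual layer: activations are split into two halves, y1 = x1 + F(x2) and y2 = x2 + G(y1), so a layer's inputs can be recomputed exactly from its outputs and need not be stored for backpropagation. A minimal sketch of that forward/inverse pair in plain JAX; `f` and `g` are placeholder functions standing in for the attention and feed-forward blocks, and none of this is Trax's API.

```python
import jax
import jax.numpy as jnp

# Illustrative reversible residual layer (hypothetical helpers, not Trax).

def rev_forward(f, g, x1, x2):
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2

def rev_inverse(f, g, y1, y2):
    # Recover the inputs from the outputs, undoing the two residual steps in reverse order
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2

# Usage with toy stand-ins for the two sub-blocks
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
wf = jax.random.normal(k1, (4, 4))
wg = jax.random.normal(k2, (4, 4))
f = lambda h: jnp.tanh(h @ wf)  # placeholder for attention
g = lambda h: jnp.tanh(h @ wg)  # placeholder for the feed-forward block
x1, x2 = jnp.split(jax.random.normal(k3, (3, 8)), 2, axis=-1)
y1, y2 = rev_forward(f, g, x1, x2)
r1, r2 = rev_inverse(f, g, y1, y2)
```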

Awesome JAX / Models and Projects / NumPyro

lqg 25 about 2 months ago Official implementation of Bayesian inverse optimal control for linear-quadratic Gaussian problems from the paper

Awesome JAX / Videos

JAX, its use at DeepMind, and discussion between engineers, scientists, and JAX core team
Introduction to JAX Simple neural network from scratch in JAX
JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas JAX's core design, how it's powering new research, and how you can start using it
Bayesian Programming with JAX + NumPyro — Andy Kitchen Introduction to Bayesian modelling using NumPyro
JAX: Accelerated machine-learning research via composable function transformations in Python | NeurIPS 2019 | Skye Wanderman-Milne JAX intro presentation in workshop
JAX on Cloud TPUs | NeurIPS 2020 | Skye Wanderman-Milne and James Bradbury Presentation of TPU host access with demo
Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond | NeurIPS 2020 Tutorial created by Zico Kolter, David Duvenaud, and Matt Johnson with Colab notebooks available on the accompanying tutorial website
Solving y=mx+b with Jax on a TPU Pod slice - Mat Kelcey A four part YouTube tutorial series with Colab notebooks that starts with Jax fundamentals and moves up to training with a data parallel approach on a v3-32 TPU Pod slice
JAX, Flax & Transformers 🤗 136,357 about 1 month ago 3 days of talks around JAX / Flax, Transformers, large-scale language modeling and other great topics

Awesome JAX / Papers

Compiling machine learning programs via high-level tracing. Roy Frostig, Matthew James Johnson, Chris Leary. MLSys 2018. White paper describing an early version of JAX, detailing how computation is traced and compiled
JAX, M.D.: A Framework for Differentiable Physics. Samuel S. Schoenholz, Ekin D. Cubuk. NeurIPS 2020. Introduces JAX, M.D., a differentiable physics library which includes simulation environments, interaction potentials, neural networks, and more
Enabling Fast Differentially Private SGD via Just-in-Time Compilation and Vectorization. Pranav Subramani, Nicholas Vadivelu, Gautam Kamath. arXiv 2020. Uses JAX's JIT and VMAP to achieve faster differentially private SGD than existing libraries
XLB: A Differentiable Massively Parallel Lattice Boltzmann Library in Python. Mohammadmehdi Ataei, Hesam Salehipour. arXiv 2023. White paper describing the XLB library: benchmarks, validations, and more details about the library
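The DP-SGD paper above relies on `jax.vmap` to compute per-example gradients and `jax.jit` to compile the whole step. A simplified sketch of one DP-SGD step (clip each example's gradient, sum, add Gaussian noise scaled to the clipping bound, average) for a linear model; the model and the `lr`, `clip`, and `noise_mult` values are illustrative assumptions, and the sketch does no privacy accounting, so on its own it provides no formal privacy guarantee.

```python
import jax
import jax.numpy as jnp

# Illustrative DP-SGD step for a linear model (hypothetical helper names).

def example_loss(w, x, y):
    return (jnp.dot(x, w) - y) ** 2  # squared error for ONE example

@jax.jit
def dp_sgd_step(w, xs, ys, key, lr=0.1, clip=1.0, noise_mult=1.0):
    # Per-example gradients: vmap the single-example gradient over the batch
    grads = jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0))(w, xs, ys)
    # Clip each example's gradient to L2 norm <= clip
    norms = jnp.linalg.norm(grads, axis=1, keepdims=True)
    grads = grads * jnp.minimum(1.0, clip / (norms + 1e-12))
    # Sum, add Gaussian noise proportional to the clipping bound, then average
    noise = noise_mult * clip * jax.random.normal(key, w.shape)
    g = (jnp.sum(grads, axis=0) + noise) / xs.shape[0]
    return w - lr * g

# Usage: noiseless linear data, 256 examples
key = jax.random.PRNGKey(0)
kx, key = jax.random.split(key)
true_w = jnp.array([1.0, -2.0, 0.5])
xs = jax.random.normal(kx, (256, 3))
ys = xs @ true_w
w = jnp.zeros(3)
for _ in range(300):
    key, sub = jax.random.split(key)
    w = dp_sgd_step(w, xs, ys, sub)
```

Without `vmap`, per-example clipping typically needs a Python loop over the batch or a batch size of one, which is the slowdown the paper targets.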

Awesome JAX / Tutorials and Blog Posts

Describes the state of JAX and the JAX ecosystem at DeepMind
Getting started with JAX (MLPs, CNNs & RNNs) by Robert Lange Neural network building blocks from scratch with the basic JAX operators
Learn JAX: From Linear Regression to Neural Networks by Rito Ghosh A gentle introduction to JAX and using it to implement Linear and Logistic Regression, and Neural Network models and using them to solve real world problems
Tutorial: image classification with JAX and Flax Linen by 8bitmp3 24 over 1 year ago Learn how to create a simple convolutional network with the Linen API by Flax and train it to recognize handwritten digits
Plugging Into JAX by Nick Doiron Compares Flax, Haiku, and Objax on the Kaggle flower classification challenge
Meta-Learning in 50 Lines of JAX by Eric Jang Introduction to both JAX and Meta-Learning
Normalizing Flows in 100 Lines of JAX by Eric Jang Concise implementation of
Differentiable Path Tracing on the GPU/TPU by Eric Jang Tutorial on implementing path tracing
Ensemble networks by Mat Kelcey Ensemble nets are a method of representing an ensemble of models as one single logical model
Out of distribution (OOD) detection by Mat Kelcey Implements different methods for OOD detection
Understanding Autodiff with JAX by Srihari Radhakrishna Understand how autodiff works using JAX
From PyTorch to JAX: towards neural net frameworks that purify stateful code by Sabrina J. Mielke Showcases how to go from a PyTorch-like style of coding to a more functional style of coding
Extending JAX with custom C++ and CUDA code by Dan Foreman-Mackey 379 5 months ago Tutorial demonstrating the infrastructure required to provide custom ops in JAX
Evolving Neural Networks in JAX by Robert Tjarko Lange Explores how JAX can power the next generation of scalable neuroevolution algorithms
Exploring hyperparameter meta-loss landscapes with JAX by Luke Metz Demonstrates how to use JAX to perform inner-loss optimization with SGD and Momentum, outer-loss optimization with gradients, and outer-loss optimization using evolutionary strategies
Deterministic ADVI in JAX by Martin Ingram Walkthrough of implementing automatic differentiation variational inference (ADVI) easily and cleanly with JAX
Evolved channel selection by Mat Kelcey Trains a classification model robust to different combinations of input channels at different resolutions, then uses a genetic algorithm to decide the best combination for a particular loss
Introduction to JAX by Kevin Murphy Colab that introduces various aspects of the language and applies them to simple ML problems
Writing an MCMC sampler in JAX by Jeremie Coullon Tutorial on the different ways to write an MCMC sampler in JAX along with speed benchmarks
How to add a progress bar to JAX scans and loops by Jeremie Coullon Tutorial on how to add a progress bar to compiled loops in JAX using the host_callback module
Get started with JAX by Aleksa Gordić 670 about 1 year ago A series of notebooks and videos going from zero JAX knowledge to building neural networks in Haiku
Writing a Training Loop in JAX + FLAX by Saurav Maheshkar and Soumik Rakshit A tutorial on writing a simple end-to-end training and evaluation pipeline in JAX, Flax and Optax
Implementing NeRF in JAX by Soumik Rakshit and Saurav Maheshkar A tutorial on 3D volumetric rendering of scenes represented by Neural Radiance Fields in JAX
Deep Learning tutorials with JAX+Flax by Phillip Lippe A series of notebooks explaining various deep learning concepts, from basics (e.g., intro to JAX/Flax, activation functions) to recent advances (e.g., Vision Transformers, SimCLR), with translations to PyTorch
Achieving 4000x Speedups with PureJaxRL A blog post on how JAX can massively speed up RL training through vectorisation
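Ensemble nets, as in the Mat Kelcey post above, represent several models as one logical model. In JAX, one way to do this is to stack the members' parameters along a leading axis and `jax.vmap` the forward pass over that axis, so the whole ensemble runs as a single compiled call. A sketch under assumed sizes (5 members, a tiny tanh MLP); `init_member` and `member_forward` are hypothetical names, not code from the post.

```python
import jax
import jax.numpy as jnp

# Illustrative vmap-based ensemble (hypothetical helpers).

def init_member(key, d_in=2, d_hidden=8):
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (d_in, d_hidden)),
        "w2": jax.random.normal(k2, (d_hidden, 1)) / jnp.sqrt(d_hidden),
    }

def member_forward(params, x):
    return jnp.tanh(x @ params["w1"]) @ params["w2"]

# Stack 5 members' parameters: every leaf gains a leading axis of size 5
keys = jax.random.split(jax.random.PRNGKey(0), 5)
ensemble_params = jax.vmap(init_member)(keys)

# One logical model: map over the member axis, share the same input batch
ensemble_forward = jax.jit(jax.vmap(member_forward, in_axes=(0, None)))

x = jnp.ones((3, 2))
preds = ensemble_forward(ensemble_params, x)  # shape (5, 3, 1)
mean_pred = preds.mean(axis=0)                # ensemble average, shape (3, 1)
```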

Awesome JAX / Books

A hands-on guide to using JAX for deep learning and other mathematically-intensive applications

Awesome JAX / Community

JAX GitHub Discussions 30,744 about 2 months ago
Reddit
