awesome-jax
JAX library catalog
A curated list of resources and libraries for JAX machine learning research
JAX - A curated list of resources https://github.com/google/jax
2k stars
51 watching
134 forks
last commit: 7 months ago
Linked from 2 awesome lists
autograd, awesome, awesome-list, deep-learning, jax, machine-learning, neural-network, numpy, xla
Awesome JAX / Libraries | |||
Awesome JAX / Libraries / Neural Network Libraries | |||
Centered on flexibility and clarity | |||
Haiku | 2,921 | about 2 months ago | Focused on simplicity, created by the authors of Sonnet at DeepMind |
Objax | 772 | about 1 year ago | Has an object oriented design similar to PyTorch |
Elegy | A High Level API for Deep Learning in JAX. Supports Flax, Haiku, and Optax | ||
Trax | 8,114 | about 2 months ago | "Batteries included" deep learning library focused on providing solutions for common workloads |
Jraph | 1,380 | 11 months ago | Lightweight graph neural network library |
Neural Tangents | 2,291 | 11 months ago | High-level API for specifying neural networks of both finite and infinite width |
HuggingFace | 136,357 | about 1 month ago | Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax) |
Equinox | 2,157 | about 2 months ago | Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX |
Scenic | 3,363 | about 2 months ago | A Jax Library for Computer Vision Research and Beyond |
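Several of the libraries above (Equinox most explicitly) represent a model as a PyTree, so JAX transformations such as `jax.grad` and `jax.jit` can act on the model object itself. The sketch below illustrates that idea using only core JAX and `jax.tree_util.register_pytree_node_class`; `Linear` and `loss` are hypothetical names, and this is not the API of Equinox or any other listed library.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Linear:
    """Hypothetical minimal layer: a callable object that is also a PyTree."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return x @ self.weight + self.bias

    def tree_flatten(self):
        # The array fields are the PyTree leaves; there is no static aux data.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss(model, x, y):
    # Mean squared error of the model's predictions.
    return jnp.mean((model(x) - y) ** 2)


key = jax.random.PRNGKey(0)
model = Linear(jax.random.normal(key, (3, 1)), jnp.zeros((1,)))
x = jnp.ones((4, 3))
y = jnp.zeros((4, 1))

# jax.grad differentiates with respect to every leaf of the model PyTree
# and returns the gradients packaged in the same Linear structure.
grads = jax.jit(jax.grad(loss))(model, x, y)
```

Because the gradients come back as a `Linear`, an update such as `jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, model, grads)` preserves the model's structure.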
Awesome JAX / Libraries | |||
Levanter | 527 | about 1 month ago | Legible, Scalable, Reproducible Foundation Models with Named Tensors and JAX |
EasyLM | 2,428 | 6 months ago | LLMs made easy: Pre-training, finetuning, evaluating and serving LLMs in JAX/Flax |
NumPyro | 2,341 | about 2 months ago | Probabilistic programming based on the Pyro library |
Chex | 797 | about 2 months ago | Utilities to write and test reliable JAX code |
Optax | 1,730 | about 1 month ago | Gradient processing and optimization library |
RLax | 1,272 | 4 months ago | Library for implementing reinforcement learning agents |
JAX, M.D. | 1,204 | 2 months ago | Accelerated, differentiable molecular dynamics |
Coax | 168 | almost 2 years ago | Turn RL papers into code, the easy way |
Distrax | 538 | about 2 months ago | Reimplementation of TensorFlow Probability, containing probability distributions and bijectors |
cvxpylayers | 1,843 | about 2 months ago | Construct differentiable convex optimization layers |
TensorLy | 1,576 | about 1 month ago | Tensor learning made simple |
NetKet | 554 | about 2 months ago | Machine Learning toolbox for Quantum Physics |
Fortuna | 892 | 2 months ago | AWS library for Uncertainty Quantification in Deep Learning |
BlackJAX | 858 | 3 months ago | Library of samplers for JAX |
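The "gradient processing" that Optax provides can be illustrated by hand: a transformation (here, clipping by global norm) is applied to the gradient PyTree before the parameter update. The following is a minimal sketch in plain JAX, not Optax's implementation; `clip_by_global_norm`, `sgd_update`, `loss`, and `train_step` are hypothetical names. In Optax itself, a comparable pipeline would be assembled from `optax.chain`, `optax.clip_by_global_norm`, and `optax.sgd`.

```python
import jax
import jax.numpy as jnp


def clip_by_global_norm(grads, max_norm):
    """Rescale a gradient PyTree so its global L2 norm is at most max_norm."""
    leaves = jax.tree_util.tree_leaves(grads)
    global_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    scale = jnp.minimum(1.0, max_norm / (global_norm + 1e-6))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)


def sgd_update(params, grads, learning_rate):
    """Plain gradient descent applied leaf by leaf."""
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)


def loss(params, x, y):
    # Linear model with mean squared error.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, x, y):
    grads = jax.grad(loss)(params, x, y)
    grads = clip_by_global_norm(grads, max_norm=1.0)  # gradient transformation
    return sgd_update(params, grads, learning_rate=0.1)  # parameter update
```

Keeping the transformation separate from the update rule is what makes the pieces composable: another transformation can be inserted between `jax.grad` and `sgd_update` without touching either.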
Awesome JAX / Libraries / New Libraries | |||
Awesome JAX / Libraries / New Libraries / Neural Network Libraries | |||
Federated learning in JAX, built on Optax and Haiku | |||
Equivariant MLP | 257 | over 1 year ago | Construct equivariant neural network layers |
jax-resnet | 105 | over 2 years ago | Implementations and checkpoints for ResNet variants in Flax |
Parallax | 155 | over 4 years ago | Immutable Torch Modules for JAX |
Awesome JAX / Libraries / New Libraries | |||
jax-unirep | 104 | 5 months ago | Library implementing the for protein machine learning applications |
jax-flows | 275 | over 1 year ago | Normalizing flows in JAX |
sklearn-jax-kernels | 42 | over 4 years ago | kernel matrices using JAX |
jax-cosmo | 182 | 6 months ago | Differentiable cosmology library |
efax | 58 | about 1 month ago | Exponential Families in JAX |
mpi4jax | 453 | about 2 months ago | Combine MPI operations with your Jax code on CPUs and GPUs |
imax | 37 | 10 months ago | Image augmentations and transformations |
FlaxVision | 44 | about 4 years ago | Flax version of TorchVision |
Oryx | 4,274 | about 2 months ago | Probabilistic programming language based on program transformations |
Optimal Transport Tools | 213 | almost 3 years ago | Toolbox that bundles utilities to solve optimal transport problems |
delta PV | 60 | almost 2 years ago | A photovoltaic simulator with automatic differentiation |
jaxlie | 236 | 2 months ago | Lie theory library for rigid body transformations and optimization |
BRAX | 2,397 | about 2 months ago | Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments |
flaxmodels | 240 | over 1 year ago | Pretrained models for Jax/Flax |
CR.Sparse | 88 | over 1 year ago | XLA accelerated algorithms for sparse representations and compressive sensing |
exojax | 57 | about 2 months ago | Automatic differentiable spectrum modeling of exoplanets/brown dwarfs compatible with JAX |
JAXopt | 941 | 4 months ago | Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX |
PIX | 395 | about 2 months ago | PIX is an image processing library in JAX, for JAX |
bayex | 86 | 2 months ago | Bayesian Optimization powered by JAX |
JaxDF | 124 | 4 months ago | Framework for differentiable simulators with arbitrary discretizations |
tree-math | 194 | about 2 months ago | Convert functions that operate on arrays into functions that operate on PyTrees |
jax-models | 151 | over 2 years ago | Implementations of research papers originally without code or code written with frameworks other than JAX |
PGMax | 64 | 4 months ago | A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX |
EvoJAX | 856 | 7 months ago | Hardware-Accelerated Neuroevolution |
evosax | 514 | 3 months ago | JAX-Based Evolution Strategies |
SymJAX | 120 | over 1 year ago | Symbolic CPU/GPU/TPU programming |
mcx | 328 | 11 months ago | Express & compile probabilistic programs for performant inference |
Einshape | 100 | 7 months ago | DSL-based reshaping library for JAX and other frameworks |
ALX | 34,478 | about 2 months ago | Open-source library for distributed matrix factorization using Alternating Least Squares, more info in |
Diffrax | 1,480 | about 2 months ago | Numerical differential equation solvers in JAX |
tinygp | 297 | about 2 months ago | The tiniest of Gaussian process libraries in JAX |
gymnax | 669 | 7 months ago | Reinforcement Learning Environments with the well-known gym API |
Mctx | 2,377 | about 2 months ago | Monte Carlo tree search algorithms in native JAX |
KFAC-JAX | 252 | about 2 months ago | Second Order Optimization with Approximate Curvature for NNs |
TF2JAX | 109 | about 1 month ago | Convert TensorFlow functions/graphs to JAX functions |
jwave | 145 | 4 months ago | A library for differentiable acoustic simulations |
GPJax | 467 | 2 months ago | Gaussian processes in JAX |
Jumanji | 657 | about 2 months ago | A Suite of Industry-Driven Hardware-Accelerated RL Environments written in JAX |
Eqxvision | 102 | 6 months ago | Equinox version of Torchvision |
JAXFit | 53 | over 1 year ago | Accelerated curve fitting library for nonlinear least-squares problems (see ) |
econpizza | 79 | about 2 months ago | Solve macroeconomic models with heterogeneous agents using JAX |
SPU | 246 | about 1 month ago | A domain-specific compiler and runtime suite to run JAX code with MPC (Secure Multi-Party Computation) |
jax-tqdm | 100 | 3 months ago | Add a tqdm progress bar to JAX scans and loops |
safejax | 42 | 8 months ago | Serialize JAX, Flax, Haiku, or Objax model params with 🤗 |
Kernex | 67 | over 1 year ago | Differentiable stencil decorators in JAX |
MaxText | 1,557 | about 1 month ago | A simple, performant and scalable Jax LLM written in pure Python/Jax and targeting Google Cloud TPUs |
Pax | 461 | about 2 months ago | A Jax-based machine learning framework for training large scale models |
Praxis | 178 | about 1 month ago | The layer library for Pax with a goal to be usable by other JAX-based ML projects |
purejaxrl | 755 | 5 months ago | Vectorisable, end-to-end RL algorithms in JAX |
Lorax | 134 | 11 months ago | Automatically apply LoRA to JAX models (Flax, Haiku, etc.) |
SCICO | 107 | about 2 months ago | Scientific computational imaging in JAX |
Spyx | 104 | 4 months ago | Spiking Neural Networks in JAX for machine learning on neuromorphic hardware |
BrainPy | 541 | about 1 month ago | Brain Dynamics Programming in Python |
OTT-JAX | 550 | about 2 months ago | Optimal transport tools in JAX |
QDax | 270 | 4 months ago | Quality Diversity optimization in Jax |
JAX Toolbox | 268 | about 1 month ago | Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries such as T5x, Paxml, and Transformer Engine |
Pgx | Vectorized board game environments for RL with an AlphaZero example | ||
EasyDeL | 212 | about 1 month ago | Open-source library for faster, more optimized training and serving of models such as Llama, MPT, Mixtral, and Falcon in JAX |
XLB | 241 | about 2 months ago | A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning |
dynamiqs | 184 | about 1 month ago | High-performance and differentiable simulations of quantum systems with JAX |
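Several libraries in this section (Diffrax, BRAX, JAX, M.D., dynamiqs) rely on numerical solves that JAX can differentiate through. The sketch below shows the basic idea with a fixed-step explicit Euler solver built on `jax.lax.scan`, so `jax.grad` can differentiate the final state with respect to a model parameter. It is an illustrative assumption, far simpler than what Diffrax provides, and `euler_solve` and `decay` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def euler_solve(f, y0, t0, t1, num_steps, args):
    """Fixed-step explicit Euler integration of dy/dt = f(t, y, args)."""
    dt = (t1 - t0) / num_steps
    ts = t0 + dt * jnp.arange(num_steps)  # left endpoint of each step

    def step(y, t):
        # One Euler step; scan carries the state y and ignores per-step output.
        return y + dt * f(t, y, args), None

    y_final, _ = jax.lax.scan(step, y0, ts)
    return y_final


def decay(t, y, k):
    # Exponential decay dy/dt = -k * y (t is unused but kept for the signature).
    return -k * y
```

For dy/dt = -k y with y(0) = 1, the value at t = 1 approaches exp(-k) as `num_steps` grows, and `jax.grad` of the solve with respect to `k` approaches -exp(-k).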
Awesome JAX / Models and Projects / JAX | |||
Official implementation of | |||
kalman-jax | 97 | over 1 year ago | Approximate inference for Markov (i.e., temporal) Gaussian processes using iterated Kalman filtering and smoothing |
jaxns | 156 | about 2 months ago | Nested sampling in JAX |
Amortized Bayesian Optimization | 34,478 | about 2 months ago | Code related to |
Accurate Quantized Training | 34,478 | about 2 months ago | Tools and libraries for running and analyzing neural network quantization experiments in JAX and Flax |
BNN-HMC | 34,478 | about 2 months ago | Implementation for the paper |
JAX-DFT | 34,478 | about 2 months ago | One-dimensional density functional theory (DFT) in JAX, with implementation of |
Robust Loss | 34,478 | about 2 months ago | Reference code for the paper |
Symbolic Functionals | 34,478 | about 2 months ago | Demonstration from |
TriMap | 34,478 | about 2 months ago | Official JAX implementation of |
Awesome JAX / Models and Projects / Flax | |||
Performer | 34,478 | about 2 months ago | Flax implementation of the Performer (linear transformer via FAVOR+) architecture |
JaxNeRF | 34,478 | about 2 months ago | Implementation of with multi-device GPU/TPU support |
mip-NeRF | 905 | over 2 years ago | Official implementation of |
RegNeRF | 34,478 | about 2 months ago | Official implementation of |
Big Transfer (BiT) | 1,516 | 6 months ago | Implementation of |
JAX RL | 640 | over 2 years ago | Implementations of reinforcement learning algorithms |
gMLP | Implementation of | ||
MLP Mixer | Minimal implementation of | ||
Distributed Shampoo | 34,478 | about 2 months ago | Implementation of |
NesT | 195 | 6 months ago | Official implementation of |
XMC-GAN | 98 | 3 months ago | Official implementation of |
FNet | 34,478 | about 2 months ago | Official implementation of |
GFSA | 34,478 | about 2 months ago | Official implementation of |
IPA-GNN | 34,478 | about 2 months ago | Official implementation of |
Flax Models | 34,478 | about 2 months ago | Collection of models and methods implemented in Flax |
Protein LM | 34,478 | about 2 months ago | Implements BERT and autoregressive models for proteins, as described in and |
Slot Attention | 34,478 | about 2 months ago | Reference implementation for |
Vision Transformer | 10,620 | about 2 months ago | Official implementation of |
FID computation | 24 | 7 months ago | Port of to Flax |
ARDM | 34,478 | about 2 months ago | Official implementation of |
D3PM | 34,478 | about 2 months ago | Official implementation of |
Gumbel-max Causal Mechanisms | 34,478 | about 2 months ago | Code for , with extra code in |
Latent Programmer | 34,478 | about 2 months ago | Code for the ICML 2021 paper |
SNeRG | 34,478 | about 2 months ago | Official implementation of |
Spin-weighted Spherical CNNs | 34,478 | about 2 months ago | Adaptation of |
VDVAE | 34,478 | about 2 months ago | Adaptation of , original code at |
MUSIQ | 34,478 | about 2 months ago | Checkpoints and model inference code for the ICCV 2021 paper |
AQuaDem | 34,478 | about 2 months ago | Official implementation of |
Combiner | 34,478 | about 2 months ago | Official implementation of |
Dreamfields | 34,478 | about 2 months ago | Official implementation of the ICLR 2022 paper |
GIFT | 34,478 | about 2 months ago | Official implementation of |
Light Field Neural Rendering | 34,478 | about 2 months ago | Official implementation of |
Sharpened Cosine Similarity in JAX by Raphael Pisoni | A JAX/Flax implementation of the Sharpened Cosine Similarity layer | ||
GNNs for Solving Combinatorial Optimization Problems | 43 | almost 2 years ago | A JAX + Flax implementation of |
Awesome JAX / Models and Projects / Haiku | |||
AlphaFold | 12,997 | 2 months ago | Implementation of the inference pipeline of AlphaFold v2.0, presented in |
Adversarial Robustness | 13,329 | 2 months ago | Reference code for and |
Bootstrap Your Own Latent | 13,329 | 2 months ago | Implementation for the paper |
Gated Linear Networks | 13,329 | 2 months ago | GLNs are a family of backpropagation-free neural networks |
Glassy Dynamics | 13,329 | 2 months ago | Open source implementation of the paper |
MMV | 13,329 | 2 months ago | Code for the models in |
Normalizer-Free Networks | 13,329 | 2 months ago | Official Haiku implementation of |
NuX | 82 | about 1 year ago | Normalizing flows with JAX |
OGB-LSC | 13,329 | 2 months ago | This repository contains DeepMind's entry to the (quantum chemistry) and (academic graph) tracks of the (OGB-LSC) |
Persistent Evolution Strategies | 34,478 | about 2 months ago | Code used for the paper |
Two Player Auction Learning | 0 | about 1 year ago | JAX implementation of the paper |
WikiGraphs | 13,329 | 2 months ago | Baseline code to reproduce results in |
Awesome JAX / Models and Projects / Trax | |||
Reformer | 8,114 | about 2 months ago | Implementation of the Reformer (efficient transformer) architecture |
Awesome JAX / Models and Projects / NumPyro | |||
lqg | 25 | about 2 months ago | Official implementation of Bayesian inverse optimal control for linear-quadratic Gaussian problems from the paper |
Awesome JAX / Videos | |||
JAX, its use at DeepMind, and discussion between engineers, scientists, and JAX core team | |||
Introduction to JAX | Simple neural network from scratch in JAX | ||
JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas | JAX's core design, how it's powering new research, and how you can start using it | ||
Bayesian Programming with JAX + NumPyro — Andy Kitchen | Introduction to Bayesian modelling using NumPyro | ||
JAX: Accelerated machine-learning research via composable function transformations in Python | NeurIPS 2019 | Skye Wanderman-Milne | JAX intro presentation in workshop | ||
JAX on Cloud TPUs | NeurIPS 2020 | Skye Wanderman-Milne and James Bradbury | Presentation of TPU host access with demo | ||
Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond | NeurIPS 2020 | Tutorial created by Zico Kolter, David Duvenaud, and Matt Johnson with Colab notebooks available in | ||
Solving y=mx+b with Jax on a TPU Pod slice - Mat Kelcey | A four part YouTube tutorial series with Colab notebooks that starts with Jax fundamentals and moves up to training with a data parallel approach on a v3-32 TPU Pod slice | ||
JAX, Flax & Transformers 🤗 | 136,357 | about 1 month ago | 3 days of talks around JAX / Flax, Transformers, large-scale language modeling and other great topics |
Awesome JAX / Papers | |||
Compiling machine learning programs via high-level tracing. Roy Frostig, Matthew James Johnson, Chris Leary. MLSys 2018. | White paper describing an early version of JAX, detailing how computation is traced and compiled | ||
JAX, M.D.: A Framework for Differentiable Physics. Samuel S. Schoenholz, Ekin D. Cubuk. NeurIPS 2020. | Introduces JAX, M.D., a differentiable physics library which includes simulation environments, interaction potentials, neural networks, and more | ||
Enabling Fast Differentially Private SGD via Just-in-Time Compilation and Vectorization. Pranav Subramani, Nicholas Vadivelu, Gautam Kamath. arXiv 2020. | Uses JAX's JIT and vmap to achieve faster differentially private SGD than existing libraries | ||
XLB: A Differentiable Massively Parallel Lattice Boltzmann Library in Python. Mohammadmehdi Ataei, Hesam Salehipour. arXiv 2023. | White paper describing the XLB library: benchmarks, validations, and more details about the library | ||
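The per-example-gradient pattern behind the DP-SGD paper above can be sketched with `jax.vmap` over `jax.grad`, followed by per-example clipping and Gaussian noise. This is a simplified illustration for a hypothetical linear model with assumed hyperparameters (`lr`, `clip_norm`, `noise_multiplier`); it omits privacy accounting and is not the paper's code.

```python
import jax
import jax.numpy as jnp


def loss(params, x, y):
    # Squared error of a hypothetical linear model on a single example.
    return (jnp.dot(x, params) - y) ** 2


def dp_sgd_step(params, xs, ys, key, lr=0.1, clip_norm=1.0, noise_multiplier=1.0):
    # Per-example gradients: vmap over the batch axis of xs and ys only.
    per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(params, xs, ys)
    # Clip each example's gradient to L2 norm <= clip_norm.
    norms = jnp.linalg.norm(per_example_grads, axis=1, keepdims=True)
    clipped = per_example_grads * jnp.minimum(1.0, clip_norm / (norms + 1e-12))
    # Sum, add Gaussian noise scaled to the clipping bound, then average.
    noise = noise_multiplier * clip_norm * jax.random.normal(key, params.shape)
    noisy_mean = (jnp.sum(clipped, axis=0) + noise) / xs.shape[0]
    return params - lr * noisy_mean


# JIT-compile the whole step, including the vectorized per-example gradients.
dp_sgd_step = jax.jit(dp_sgd_step)
```

The speedup the paper targets comes from this structure: `vmap` computes all per-example gradients in one batched computation instead of a Python loop over examples, and `jit` fuses the clipping and noise steps into the same compiled program.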
Awesome JAX / Tutorials and Blog Posts | |||
Describes the state of JAX and the JAX ecosystem at DeepMind | |||
Getting started with JAX (MLPs, CNNs & RNNs) by Robert Lange | Neural network building blocks from scratch with the basic JAX operators | ||
Learn JAX: From Linear Regression to Neural Networks by Rito Ghosh | A gentle introduction to JAX that uses it to implement linear regression, logistic regression, and neural network models, and applies them to real-world problems | ||
Tutorial: image classification with JAX and Flax Linen by 8bitmp3 | 24 | over 1 year ago | Learn how to create a simple convolutional network with the Linen API by Flax and train it to recognize handwritten digits |
Plugging Into JAX by Nick Doiron | Compares Flax, Haiku, and Objax on the Kaggle flower classification challenge | ||
Meta-Learning in 50 Lines of JAX by Eric Jang | Introduction to both JAX and Meta-Learning | ||
Normalizing Flows in 100 Lines of JAX by Eric Jang | Concise implementation of | ||
Differentiable Path Tracing on the GPU/TPU by Eric Jang | Tutorial on implementing path tracing | ||
Ensemble networks by Mat Kelcey | Ensemble nets are a method of representing an ensemble of models as one single logical model | ||
Out of distribution (OOD) detection by Mat Kelcey | Implements different methods for OOD detection | ||
Understanding Autodiff with JAX by Srihari Radhakrishna | Understand how autodiff works using JAX | ||
From PyTorch to JAX: towards neural net frameworks that purify stateful code by Sabrina J. Mielke | Showcases how to go from a PyTorch-like style of coding to a more functional style of coding | ||
Extending JAX with custom C++ and CUDA code by Dan Foreman-Mackey | 379 | 5 months ago | Tutorial demonstrating the infrastructure required to provide custom ops in JAX |
Evolving Neural Networks in JAX by Robert Tjarko Lange | Explores how JAX can power the next generation of scalable neuroevolution algorithms | ||
Exploring hyperparameter meta-loss landscapes with JAX by Luke Metz | Demonstrates how to use JAX to perform inner-loss optimization with SGD and Momentum, outer-loss optimization with gradients, and outer-loss optimization using evolutionary strategies | ||
Deterministic ADVI in JAX by Martin Ingram | Walkthrough of implementing automatic differentiation variational inference (ADVI) easily and cleanly with JAX | ||
Evolved channel selection by Mat Kelcey | Trains a classification model robust to different combinations of input channels at different resolutions, then uses a genetic algorithm to decide the best combination for a particular loss | ||
Introduction to JAX by Kevin Murphy | Colab that introduces various aspects of the language and applies them to simple ML problems | ||
Writing an MCMC sampler in JAX by Jeremie Coullon | Tutorial on the different ways to write an MCMC sampler in JAX along with speed benchmarks | ||
How to add a progress bar to JAX scans and loops by Jeremie Coullon | Tutorial on how to add a progress bar to compiled loops in JAX using the module | ||
Get started with JAX by Aleksa Gordić | 670 | about 1 year ago | A series of notebooks and videos going from zero JAX knowledge to building neural networks in Haiku |
Writing a Training Loop in JAX + FLAX by Saurav Maheshkar and Soumik Rakshit | A tutorial on writing a simple end-to-end training and evaluation pipeline in JAX, Flax and Optax | ||
Implementing NeRF in JAX by Soumik Rakshit and Saurav Maheshkar | A tutorial on 3D volumetric rendering of scenes represented by Neural Radiance Fields in JAX | ||
Deep Learning tutorials with JAX+Flax by Phillip Lippe | A series of notebooks explaining various deep learning concepts, from basics (e.g., intro to JAX/Flax, activation functions) to recent advances (e.g., Vision Transformers, SimCLR), with translations to PyTorch | ||
Achieving 4000x Speedups with PureJaxRL | A blog post on how JAX can massively speed up RL training through vectorisation | ||
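As a companion to the MCMC sampler tutorial listed above, here is a minimal random-walk Metropolis sampler that uses `jax.lax.scan` so the whole chain compiles as a single loop. The standard-normal target, the `step_size` default, and the names `log_prob` and `rwm_sampler` are assumptions for illustration, not taken from the tutorial.

```python
import jax
import jax.numpy as jnp


def log_prob(x):
    # Unnormalised log-density of a standard normal target (assumed).
    return -0.5 * jnp.sum(x ** 2)


def rwm_sampler(key, x0, num_samples, step_size=1.0):
    """Random-walk Metropolis with Gaussian proposals."""

    def step(x, key):
        key_prop, key_accept = jax.random.split(key)
        proposal = x + step_size * jax.random.normal(key_prop, x.shape)
        # Symmetric proposal, so the acceptance ratio is a density ratio.
        log_accept_ratio = log_prob(proposal) - log_prob(x)
        accept = jnp.log(jax.random.uniform(key_accept)) < log_accept_ratio
        x_next = jnp.where(accept, proposal, x)
        return x_next, x_next  # carry the state and record it as a sample

    keys = jax.random.split(key, num_samples)
    _, samples = jax.lax.scan(step, x0, keys)
    return samples


# num_samples sets the number of keys, so it must be static under jit.
rwm_sampler = jax.jit(rwm_sampler, static_argnums=2)
```

Passing one pre-split key per step through `scan` keeps the sampler purely functional, which is also what lets `jax.vmap` run many independent chains at once.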
Awesome JAX / Books | |||
A hands-on guide to using JAX for deep learning and other mathematically-intensive applications | |||
Awesome JAX / Community | |||
JAX GitHub Discussions | 30,744 | about 2 months ago | |