awesome-jax
JAX library catalog
A curated list of resources and libraries for JAX machine learning research
JAX - A curated list of resources https://github.com/google/jax
2k stars
51 watching
134 forks
last commit: over 1 year ago
Linked from 2 awesome lists
Topics: autograd, awesome, awesome-list, deep-learning, jax, machine-learning, neural-network, numpy, xla
Awesome JAX / Libraries
Awesome JAX / Libraries / Neural Network Libraries
| | | | Centered on flexibility and clarity |
| Haiku | 2,921 | 11 months ago | Focused on simplicity, created by the authors of Sonnet at DeepMind |
| Objax | 772 | almost 2 years ago | Has an object-oriented design similar to PyTorch |
| Elegy | | | A high-level API for deep learning in JAX. Supports Flax, Haiku, and Optax |
| Trax | 8,114 | 11 months ago | "Batteries included" deep learning library focused on providing solutions for common workloads |
| Jraph | 1,380 | over 1 year ago | Lightweight graph neural network library |
| Neural Tangents | 2,291 | over 1 year ago | High-level API for specifying neural networks of both finite and infinite width |
| HuggingFace | 136,357 | 11 months ago | Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax) |
| Equinox | 2,157 | 11 months ago | Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX |
| Scenic | 3,363 | 11 months ago | A Jax Library for Computer Vision Research and Beyond |
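Equinox's description above (callable PyTrees plus filtered JIT/grad) rests on a core JAX idea: if a model is itself a PyTree, `jax.grad` and `jax.jit` can take it as an argument directly. A minimal sketch of that idea using only core JAX, assuming a linear model and squared loss. Equinox provides `eqx.Module` for this; the `Linear`, `loss`, and `sgd_step` names below are hypothetical stand-ins, and the sketch leaves out Equinox's "filtering" of non-array fields.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Linear(NamedTuple):
    """A tiny callable model whose fields are its parameters.

    NamedTuples are PyTrees in JAX, so an instance can be passed straight
    through jax.grad / jax.jit, and its fields become the PyTree leaves.
    """
    w: jnp.ndarray
    b: jnp.ndarray

    def __call__(self, x):
        return x @ self.w + self.b


def loss(model, x, y):
    # Mean squared error of the model's predictions.
    return jnp.mean((model(x) - y) ** 2)


@jax.jit
def sgd_step(model, x, y, lr=0.1):
    # The gradient has the same structure as the model: a Linear of arrays.
    grads = jax.grad(loss)(model, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, model, grads)
```

Because the update is a `tree_map` over the model and its gradient, `sgd_step` returns a new `Linear` rather than mutating the old one, which matches JAX's functional style.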
Awesome JAX / Libraries
| Levanter | 527 | 11 months ago | Legible, Scalable, Reproducible Foundation Models with Named Tensors and JAX |
| EasyLM | 2,428 | about 1 year ago | LLMs made easy: Pre-training, finetuning, evaluating and serving LLMs in JAX/Flax |
| NumPyro | 2,341 | 11 months ago | Probabilistic programming based on the Pyro library |
| Chex | 797 | 11 months ago | Utilities to write and test reliable JAX code |
| Optax | 1,730 | 11 months ago | Gradient processing and optimization library |
| RLax | 1,272 | about 1 year ago | Library for implementing reinforcement learning agents |
| JAX, M.D. | 1,204 | 12 months ago | Accelerated, differentiable molecular dynamics |
| Coax | 168 | almost 3 years ago | Turn RL papers into code, the easy way |
| Distrax | 538 | 11 months ago | Reimplementation of TensorFlow Probability, containing probability distributions and bijectors |
| cvxpylayers | 1,843 | 11 months ago | Construct differentiable convex optimization layers |
| TensorLy | 1,576 | 11 months ago | Tensor learning made simple |
| NetKet | 554 | 11 months ago | Machine Learning toolbox for Quantum Physics |
| Fortuna | 892 | 12 months ago | AWS library for Uncertainty Quantification in Deep Learning |
| BlackJAX | 858 | about 1 year ago | Library of samplers for JAX |
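Optax (listed above) builds optimizers by composing gradient transformations, for example `optax.chain(optax.clip_by_global_norm(1.0), optax.sgd(0.01, momentum=0.9))`. To show roughly what such a chain computes, here is a hand-rolled sketch in plain JAX over arbitrary gradient PyTrees. The helper names (`global_norm`, `clip_by_global_norm`, `momentum_update`, `train_step`) are hypothetical and are not Optax's API.

```python
import jax
import jax.numpy as jnp


def global_norm(tree):
    # L2 norm taken over every leaf of the gradient PyTree together.
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))


def clip_by_global_norm(grads, max_norm):
    # Rescale all leaves by one shared factor so the global norm is <= max_norm.
    norm = global_norm(grads)
    scale = jnp.minimum(1.0, max_norm / (norm + 1e-6))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)


def momentum_update(params, grads, velocity, lr=0.01, beta=0.9):
    # Heavy-ball momentum: v <- beta * v + g, then p <- p - lr * v.
    velocity = jax.tree_util.tree_map(lambda v, g: beta * v + g, velocity, grads)
    params = jax.tree_util.tree_map(lambda p, v: p - lr * v, params, velocity)
    return params, velocity


@jax.jit
def train_step(params, velocity, grads):
    # Chain the two transformations: clip first, then apply momentum SGD.
    grads = clip_by_global_norm(grads, max_norm=1.0)
    return momentum_update(params, grads, velocity)
```

Clipping by the global norm (rather than per leaf) preserves the direction of the full gradient, which is why it is a common first stage in such chains.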
Awesome JAX / Libraries / New Libraries
Awesome JAX / Libraries / New Libraries / Neural Network Libraries
| | | | Federated learning in JAX, built on Optax and Haiku |
| Equivariant MLP | 257 | over 2 years ago | Construct equivariant neural network layers |
| jax-resnet | 105 | over 3 years ago | Implementations and checkpoints for ResNet variants in Flax |
| Parallax | 155 | over 5 years ago | Immutable Torch Modules for JAX |
Awesome JAX / Libraries / New Libraries
| jax-unirep | 104 | about 1 year ago | Library implementing the for protein machine learning applications |
| jax-flows | 275 | over 2 years ago | Normalizing flows in JAX |
| sklearn-jax-kernels | 42 | about 5 years ago | scikit-learn kernel matrices using JAX |
| jax-cosmo | 182 | over 1 year ago | Differentiable cosmology library |
| efax | 58 | 11 months ago | Exponential Families in JAX |
| mpi4jax | 453 | 11 months ago | Combine MPI operations with your Jax code on CPUs and GPUs |
| imax | 37 | over 1 year ago | Image augmentations and transformations |
| FlaxVision | 44 | almost 5 years ago | Flax version of TorchVision |
| Oryx | 4,274 | 11 months ago | Probabilistic programming language based on program transformations |
| Optimal Transport Tools | 213 | almost 4 years ago | Toolbox that bundles utilities to solve optimal transport problems |
| delta PV | 60 | almost 3 years ago | A photovoltaic simulator with automatic differentiation |
| jaxlie | 236 | 11 months ago | Lie theory library for rigid body transformations and optimization |
| BRAX | 2,397 | 11 months ago | Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments |
| flaxmodels | 240 | about 2 years ago | Pretrained models for Jax/Flax |
| CR.Sparse | 88 | about 2 years ago | XLA accelerated algorithms for sparse representations and compressive sensing |
| exojax | 57 | 11 months ago | Automatic differentiable spectrum modeling of exoplanets/brown dwarfs, compatible with JAX |
| JAXopt | 941 | about 1 year ago | Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX |
| PIX | 395 | 11 months ago | PIX is an image processing library in JAX, for JAX |
| bayex | 86 | 12 months ago | Bayesian Optimization powered by JAX |
| JaxDF | 124 | about 1 year ago | Framework for differentiable simulators with arbitrary discretizations |
| tree-math | 194 | 11 months ago | Convert functions that operate on arrays into functions that operate on PyTrees |
| jax-models | 151 | over 3 years ago | Implementations of research papers originally without code or code written with frameworks other than JAX |
| PGMax | 64 | about 1 year ago | A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX |
| EvoJAX | 856 | over 1 year ago | Hardware-Accelerated Neuroevolution |
| evosax | 514 | about 1 year ago | JAX-Based Evolution Strategies |
| SymJAX | 120 | over 2 years ago | Symbolic CPU/GPU/TPU programming |
| mcx | 328 | over 1 year ago | Express & compile probabilistic programs for performant inference |
| Einshape | 100 | over 1 year ago | DSL-based reshaping library for JAX and other frameworks |
| ALX | 34,478 | 11 months ago | Open-source library for distributed matrix factorization using Alternating Least Squares, more info in |
| Diffrax | 1,480 | 11 months ago | Numerical differential equation solvers in JAX |
| tinygp | 297 | 11 months ago | The tiniest of Gaussian process libraries in JAX |
| gymnax | 669 | over 1 year ago | Reinforcement Learning Environments with the well-known gym API |
| Mctx | 2,377 | 11 months ago | Monte Carlo tree search algorithms in native JAX |
| KFAC-JAX | 252 | 11 months ago | Second Order Optimization with Approximate Curvature for NNs |
| TF2JAX | 109 | 11 months ago | Convert TensorFlow functions/graphs to JAX functions |
| jwave | 145 | about 1 year ago | A library for differentiable acoustic simulations |
| GPJax | 467 | 12 months ago | Gaussian processes in JAX |
| Jumanji | 657 | 11 months ago | A Suite of Industry-Driven Hardware-Accelerated RL Environments written in JAX |
| Eqxvision | 102 | over 1 year ago | Equinox version of Torchvision |
| JAXFit | 53 | over 2 years ago | Accelerated curve fitting library for nonlinear least-squares problems (see ) |
| econpizza | 79 | 11 months ago | Solve macroeconomic models with heterogeneous agents using JAX |
| SPU | 246 | 11 months ago | A domain-specific compiler and runtime suite to run JAX code with MPC (Secure Multi-Party Computation) |
| jax-tqdm | 100 | about 1 year ago | Add a tqdm progress bar to JAX scans and loops |
| safejax | 42 | over 1 year ago | Serialize JAX, Flax, Haiku, or Objax model params with 🤗 |
| Kernex | 67 | about 2 years ago | Differentiable stencil decorators in JAX |
| MaxText | 1,557 | 11 months ago | A simple, performant and scalable Jax LLM written in pure Python/Jax and targeting Google Cloud TPUs |
| Pax | 461 | 11 months ago | A Jax-based machine learning framework for training large scale models |
| Praxis | 178 | 11 months ago | The layer library for Pax with a goal to be usable by other JAX-based ML projects |
| purejaxrl | 755 | about 1 year ago | Vectorisable, end-to-end RL algorithms in JAX |
| Lorax | 134 | over 1 year ago | Automatically apply LoRA to JAX models (Flax, Haiku, etc.) |
| SCICO | 107 | 11 months ago | Scientific computational imaging in JAX |
| Spyx | 104 | about 1 year ago | Spiking Neural Networks in JAX for machine learning on neuromorphic hardware |
| BrainPy | 541 | 11 months ago | Brain Dynamics Programming in Python |
| OTT-JAX | 550 | 11 months ago | Optimal transport tools in JAX |
| QDax | 270 | about 1 year ago | Quality Diversity optimization in Jax |
| JAX Toolbox | 268 | 11 months ago | Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries such as T5x, Paxml, and Transformer Engine |
| Pgx | | | Vectorized board game environments for RL with an AlphaZero example |
| EasyDeL | 212 | 11 months ago | EasyDeL 🔮 is an open-source library for faster, more optimized training and serving of models (Llama, MPT, Mixtral, Falcon, etc.) in JAX |
| XLB | 241 | 11 months ago | A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning |
| dynamiqs | 184 | 11 months ago | High-performance and differentiable simulations of quantum systems with JAX |
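tree-math's description above (turning array functions into PyTree functions) can be illustrated in plain JAX by flattening a PyTree into one vector with `jax.flatten_util.ravel_pytree`, calling the array function, and unflattening the result. This is a sketch of the idea only, not tree-math's API; `pytree_wrap` and `normalize` are hypothetical names, and the wrapper assumes the array function returns a vector of the same size and dtype as its input.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def pytree_wrap(vector_fn):
    """Lift a function on 1-D arrays to a function on PyTrees (same structure in and out)."""
    def wrapped(tree):
        flat, unravel = ravel_pytree(tree)   # concatenate every leaf into one 1-D vector
        return unravel(vector_fn(flat))      # restore the original nesting and leaf shapes
    return wrapped


def normalize(v):
    # An ordinary array function: scale a vector to unit L2 norm.
    return v / jnp.linalg.norm(v)


normalize_tree = pytree_wrap(normalize)
params = {"w": jnp.array([[3.0, 0.0]]), "b": jnp.array([4.0])}
unit = normalize_tree(params)   # the norm is taken over all leaves jointly
```

Flattening is simple but copies data into one buffer on every call; tree-math is described as operating on PyTrees directly instead.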
Awesome JAX / Models and Projects / JAX
| | | | Official implementation of |
| kalman-jax | 97 | over 2 years ago | Approximate inference for Markov (i.e., temporal) Gaussian processes using iterated Kalman filtering and smoothing |
| jaxns | 156 | 11 months ago | Nested sampling in JAX |
| Amortized Bayesian Optimization | 34,478 | 11 months ago | Code related to |
| Accurate Quantized Training | 34,478 | 11 months ago | Tools and libraries for running and analyzing neural network quantization experiments in JAX and Flax |
| BNN-HMC | 34,478 | 11 months ago | Implementation for the paper |
| JAX-DFT | 34,478 | 11 months ago | One-dimensional density functional theory (DFT) in JAX, with implementation of |
| Robust Loss | 34,478 | 11 months ago | Reference code for the paper |
| Symbolic Functionals | 34,478 | 11 months ago | Demonstration from |
| TriMap | 34,478 | 11 months ago | Official JAX implementation of |
Awesome JAX / Models and Projects / Flax
| Performer | 34,478 | 11 months ago | Flax implementation of the Performer (linear transformer via FAVOR+) architecture |
| JaxNeRF | 34,478 | 11 months ago | Implementation of with multi-device GPU/TPU support |
| mip-NeRF | 905 | about 3 years ago | Official implementation of |
| RegNeRF | 34,478 | 11 months ago | Official implementation of |
| Big Transfer (BiT) | 1,516 | over 1 year ago | Implementation of |
| JAX RL | 640 | about 3 years ago | Implementations of reinforcement learning algorithms |
| gMLP | | | Implementation of |
| MLP Mixer | | | Minimal implementation of |
| Distributed Shampoo | 34,478 | 11 months ago | Implementation of |
| NesT | 195 | over 1 year ago | Official implementation of |
| XMC-GAN | 98 | about 1 year ago | Official implementation of |
| FNet | 34,478 | 11 months ago | Official implementation of |
| GFSA | 34,478 | 11 months ago | Official implementation of |
| IPA-GNN | 34,478 | 11 months ago | Official implementation of |
| Flax Models | 34,478 | 11 months ago | Collection of models and methods implemented in Flax |
| Protein LM | 34,478 | 11 months ago | Implements BERT and autoregressive models for proteins, as described in and |
| Slot Attention | 34,478 | 11 months ago | Reference implementation for |
| Vision Transformer | 10,620 | 11 months ago | Official implementation of |
| FID computation | 24 | over 1 year ago | Port of to Flax |
| ARDM | 34,478 | 11 months ago | Official implementation of |
| D3PM | 34,478 | 11 months ago | Official implementation of |
| Gumbel-max Causal Mechanisms | 34,478 | 11 months ago | Code for , with extra code in |
| Latent Programmer | 34,478 | 11 months ago | Code for the ICML 2021 paper |
| SNeRG | 34,478 | 11 months ago | Official implementation of |
| Spin-weighted Spherical CNNs | 34,478 | 11 months ago | Adaptation of |
| VDVAE | 34,478 | 11 months ago | Adaptation of , original code at |
| MUSIQ | 34,478 | 11 months ago | Checkpoints and model inference code for the ICCV 2021 paper |
| AQuaDem | 34,478 | 11 months ago | Official implementation of |
| Combiner | 34,478 | 11 months ago | Official implementation of |
| Dreamfields | 34,478 | 11 months ago | Official implementation of the ICLR 2022 paper |
| GIFT | 34,478 | 11 months ago | Official implementation of |
| Light Field Neural Rendering | 34,478 | 11 months ago | Official implementation of |
| Sharpened Cosine Similarity in JAX by Raphael Pisoni | | | A JAX/Flax implementation of the Sharpened Cosine Similarity layer |
| GNNs for Solving Combinatorial Optimization Problems | 43 | over 2 years ago | A JAX + Flax implementation of |
Awesome JAX / Models and Projects / Haiku
| AlphaFold | 12,997 | 12 months ago | Implementation of the inference pipeline of AlphaFold v2.0, presented in |
| Adversarial Robustness | 13,329 | 12 months ago | Reference code for and |
| Bootstrap Your Own Latent | 13,329 | 12 months ago | Implementation for the paper |
| Gated Linear Networks | 13,329 | 12 months ago | GLNs are a family of backpropagation-free neural networks |
| Glassy Dynamics | 13,329 | 12 months ago | Open source implementation of the paper |
| MMV | 13,329 | 12 months ago | Code for the models in |
| Normalizer-Free Networks | 13,329 | 12 months ago | Official Haiku implementation of |
| NuX | 82 | almost 2 years ago | Normalizing flows with JAX |
| OGB-LSC | 13,329 | 12 months ago | This repository contains DeepMind's entry to the (quantum chemistry) and (academic graph) tracks of the (OGB-LSC) |
| Persistent Evolution Strategies | 34,478 | 11 months ago | Code used for the paper |
| Two Player Auction Learning | 0 | almost 2 years ago | JAX implementation of the paper |
| WikiGraphs | 13,329 | 12 months ago | Baseline code to reproduce results in |
Awesome JAX / Models and Projects / Trax
| Reformer | 8,114 | 11 months ago | Implementation of the Reformer (efficient transformer) architecture |
Awesome JAX / Models and Projects / NumPyro
| lqg | 25 | 11 months ago | Official implementation of Bayesian inverse optimal control for linear-quadratic Gaussian problems from the paper |
Awesome JAX / Videos
| | | | JAX, its use at DeepMind, and discussion between engineers, scientists, and the JAX core team |
| Introduction to JAX | | | Simple neural network from scratch in JAX |
| JAX: Accelerated Machine Learning Research, SciPy 2020, VanderPlas | | | JAX's core design, how it's powering new research, and how you can start using it |
| Bayesian Programming with JAX + NumPyro, Andy Kitchen | | | Introduction to Bayesian modelling using NumPyro |
| JAX: Accelerated machine-learning research via composable function transformations in Python, NeurIPS 2019, Skye Wanderman-Milne | | | JAX intro presentation in workshop |
| JAX on Cloud TPUs, NeurIPS 2020, Skye Wanderman-Milne and James Bradbury | | | Presentation of TPU host access with demo |
| Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond, NeurIPS 2020 | | | Tutorial created by Zico Kolter, David Duvenaud, and Matt Johnson with Colab notebooks available in |
| Solving y=mx+b with Jax on a TPU Pod slice - Mat Kelcey | | | A four-part YouTube tutorial series with Colab notebooks that starts with JAX fundamentals and moves up to training with a data-parallel approach on a v3-32 TPU Pod slice |
| JAX, Flax & Transformers 🤗 | 136,357 | 11 months ago | 3 days of talks around JAX / Flax, Transformers, large-scale language modeling and other great topics |
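The title problem of the "Solving y=mx+b" series above, fitting a line, is a compact demonstration of `jax.grad` and `jax.jit`. A single-device sketch, not taken from the series; the synthetic data (m = 3, b = -1), learning rate, and step count are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Noise-free synthetic data from y = 3x - 1, so the fit should recover m and b closely.
x = jnp.linspace(-1.0, 1.0, 50)
y = 3.0 * x - 1.0


def mse(params, x, y):
    m, b = params
    return jnp.mean((m * x + b - y) ** 2)


@jax.jit
def step(params, lr=0.1):
    # jax.grad differentiates with respect to the first argument: the (m, b) tuple.
    grads = jax.grad(mse)(params, x, y)
    return tuple(p - lr * g for p, g in zip(params, grads))


params = (jnp.array(0.0), jnp.array(0.0))
for _ in range(500):
    params = step(params)
m, b = params
```

The series goes on to spread this kind of training across a TPU Pod slice; the sketch stays on one device.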
Awesome JAX / Papers
| Compiling machine learning programs via high-level tracing. Roy Frostig, Matthew James Johnson, Chris Leary. MLSys 2018. | | | White paper describing an early version of JAX, detailing how computation is traced and compiled |
| JAX, M.D.: A Framework for Differentiable Physics. Samuel S. Schoenholz, Ekin D. Cubuk. NeurIPS 2020. | | | Introduces JAX, M.D., a differentiable physics library which includes simulation environments, interaction potentials, neural networks, and more |
| Enabling Fast Differentially Private SGD via Just-in-Time Compilation and Vectorization. Pranav Subramani, Nicholas Vadivelu, Gautam Kamath. arXiv 2020. | | | Uses JAX's JIT and VMAP to achieve faster differentially private SGD than existing libraries |
| XLB: A Differentiable Massively Parallel Lattice Boltzmann Library in Python. Mohammadmehdi Ataei, Hesam Salehipour. arXiv 2023. | | | White paper describing the XLB library: benchmarks, validations, and more details about the library |
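The differentially private SGD paper above credits JIT and `vmap` for its speed: DP-SGD needs a gradient per example so that each can be clipped before averaging and adding noise, and `jax.vmap(jax.grad(...))` computes those in one vectorized call. A minimal sketch of that pattern, assuming a linear model with squared loss; `loss`, `per_example_grads`, `dp_sgd_step`, and the default clip norm and noise multiplier are hypothetical, and this is neither the paper's code nor a privacy-audited implementation.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Squared error for a single example (x has shape [d], y is a scalar).
    return (jnp.dot(x, w) - y) ** 2


# Map over the batch axis of x and y, while sharing the weights w (in_axes=None).
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))


@jax.jit
def dp_sgd_step(w, xs, ys, key, lr=0.1, clip_norm=1.0, noise_multiplier=1.0):
    grads = per_example_grads(w, xs, ys)                          # shape [batch, d]
    norms = jnp.linalg.norm(grads, axis=1, keepdims=True)
    grads = grads * jnp.minimum(1.0, clip_norm / (norms + 1e-12))  # clip each example
    # Gaussian noise with std noise_multiplier * clip_norm is added to the summed gradient.
    noise = noise_multiplier * clip_norm * jax.random.normal(key, w.shape)
    noisy_mean = (grads.sum(axis=0) + noise) / xs.shape[0]
    return w - lr * noisy_mean
```

Without `vmap`, the per-example gradients would typically come from a Python loop over the batch, which is the overhead the paper's title refers to avoiding.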
Awesome JAX / Tutorials and Blog Posts
| | | | Describes the state of JAX and the JAX ecosystem at DeepMind |
| Getting started with JAX (MLPs, CNNs & RNNs) by Robert Lange | | | Neural network building blocks from scratch with the basic JAX operators |
| Learn JAX: From Linear Regression to Neural Networks by Rito Ghosh | | | A gentle introduction to JAX and to using it to implement linear regression, logistic regression, and neural network models that solve real-world problems |
| Tutorial: image classification with JAX and Flax Linen by 8bitmp3 | 24 | about 2 years ago | Learn how to create a simple convolutional network with the Linen API by Flax and train it to recognize handwritten digits |
| Plugging Into JAX by Nick Doiron | | | Compares Flax, Haiku, and Objax on the Kaggle flower classification challenge |
| Meta-Learning in 50 Lines of JAX by Eric Jang | | | Introduction to both JAX and Meta-Learning |
| Normalizing Flows in 100 Lines of JAX by Eric Jang | | | Concise implementation of |
| Differentiable Path Tracing on the GPU/TPU by Eric Jang | | | Tutorial on implementing path tracing |
| Ensemble networks by Mat Kelcey | | | Ensemble nets are a method of representing an ensemble of models as one single logical model |
| Out of distribution (OOD) detection by Mat Kelcey | | | Implements different methods for OOD detection |
| Understanding Autodiff with JAX by Srihari Radhakrishna | | | Understand how autodiff works using JAX |
| From PyTorch to JAX: towards neural net frameworks that purify stateful code by Sabrina J. Mielke | | | Showcases how to go from a PyTorch-like style of coding to a more functional style of coding |
| Extending JAX with custom C++ and CUDA code by Dan Foreman-Mackey | 379 | about 1 year ago | Tutorial demonstrating the infrastructure required to provide custom ops in JAX |
| Evolving Neural Networks in JAX by Robert Tjarko Lange | | | Explores how JAX can power the next generation of scalable neuroevolution algorithms |
| Exploring hyperparameter meta-loss landscapes with JAX by Luke Metz | | | Demonstrates how to use JAX to perform inner-loss optimization with SGD and Momentum, outer-loss optimization with gradients, and outer-loss optimization using evolutionary strategies |
| Deterministic ADVI in JAX by Martin Ingram | | | Walkthrough of implementing automatic differentiation variational inference (ADVI) easily and cleanly with JAX |
| Evolved channel selection by Mat Kelcey | | | Trains a classification model robust to different combinations of input channels at different resolutions, then uses a genetic algorithm to decide the best combination for a particular loss |
| Introduction to JAX by Kevin Murphy | | | Colab that introduces various aspects of the language and applies them to simple ML problems |
| Writing an MCMC sampler in JAX by Jeremie Coullon | | | Tutorial on the different ways to write an MCMC sampler in JAX along with speed benchmarks |
| How to add a progress bar to JAX scans and loops by Jeremie Coullon | | | Tutorial on how to add a progress bar to compiled loops in JAX using the module |
| Get started with JAX by Aleksa Gordić | 670 | almost 2 years ago | A series of notebooks and videos going from zero JAX knowledge to building neural networks in Haiku |
| Writing a Training Loop in JAX + FLAX by Saurav Maheshkar and Soumik Rakshit | | | A tutorial on writing a simple end-to-end training and evaluation pipeline in JAX, Flax, and Optax |
| Implementing NeRF in JAX by Soumik Rakshit and Saurav Maheshkar | | | A tutorial on 3D volumetric rendering of scenes represented by Neural Radiance Fields in JAX |
| Deep Learning tutorials with JAX+Flax by Phillip Lippe | | | A series of notebooks explaining various deep learning concepts, from basics (e.g., intro to JAX/Flax, activation functions) to recent advances (e.g., Vision Transformers, SimCLR), with translations to PyTorch |
| Achieving 4000x Speedups with PureJaxRL | | | A blog post on how JAX can massively speed up RL training through vectorisation |
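Mat Kelcey's "Ensemble networks" post above describes representing an ensemble of models as one logical model. In JAX a common way to do this is to stack each member's parameters along a leading axis and `vmap` the forward pass over that axis. A minimal sketch, assuming a one-hidden-layer MLP with five members; `init_member`, `forward`, and all sizes are hypothetical and not taken from the post.

```python
import jax
import jax.numpy as jnp


def init_member(key, d_in=4, d_hidden=16, d_out=1):
    # Parameters for one ensemble member.
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (d_in, d_hidden)) / jnp.sqrt(d_in),
        "b1": jnp.zeros(d_hidden),
        "w2": jax.random.normal(k2, (d_hidden, d_out)) / jnp.sqrt(d_hidden),
        "b2": jnp.zeros(d_out),
    }


def forward(params, x):
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


# vmap over a batch of 5 PRNG keys: every parameter gains a leading "member" axis.
keys = jax.random.split(jax.random.PRNGKey(0), 5)
ensemble = jax.vmap(init_member)(keys)

# Map over the parameters (axis 0) while feeding every member the same inputs.
ensemble_forward = jax.jit(jax.vmap(forward, in_axes=(0, None)))

x = jnp.ones((8, 4))
preds = ensemble_forward(ensemble, x)   # shape [5, 8, 1]
mean_pred = preds.mean(axis=0)          # ensemble average, shape [8, 1]
```

Because the members live in one stacked PyTree, a single jitted call runs all of them, and the spread of `preds` across the member axis can serve as a rough uncertainty signal.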
Awesome JAX / Books
| | | | A hands-on guide to using JAX for deep learning and other mathematically-intensive applications |
Awesome JAX / Community
| JAX GitHub Discussions | 30,744 | 11 months ago | |