jax

Acceleration tool

Accelerates numerical computing by automatically differentiating and compiling Python functions for high-performance execution on GPUs and TPUs.

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
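The tagline above names three composable transformations. Here is a minimal sketch of each, using only the public `jax.grad`, `jax.jit`, and `jax.vmap` entry points. The linear-model loss, the data, and the variable names are illustrative assumptions and are not taken from the project itself.

```python
import jax
import jax.numpy as jnp

# Illustrative loss (an assumption for this sketch): mean squared error of a linear model.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# Differentiate: jax.grad returns a new function that computes d(loss)/d(w).
grad_loss = jax.grad(loss)

# Compose with JIT: compile the gradient function with XLA for the available backend.
fast_grad = jax.jit(grad_loss)

# Vectorize: map the loss over a batch of weight vectors (axis 0 of w),
# while x and y are shared across the batch (in_axes=None).
batched_loss = jax.vmap(loss, in_axes=(0, None, None))

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.array([0.5, -0.5])

g = fast_grad(w, x, y)                           # gradient at a single w
ws = jnp.stack([w, jnp.zeros(2), jnp.ones(2)])   # three candidate weight vectors
losses = batched_loss(ws, x, y)                  # one loss value per row of ws
print(g, losses)
```

Because each transformation returns an ordinary Python function, they nest freely: `jax.jit(jax.grad(loss))` above is one such composition, and `jax.vmap(jax.grad(loss), ...)` would yield per-example gradients in the same way.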

GitHub

30k stars
337 watching
3k forks
Language: Python
Last commit: 6 days ago
Linked from 6 awesome lists



Related projects:

Repository | Description | Stars
--- | --- | ---
dfm/extending-jax | Provides infrastructure for interfacing custom C++ and CUDA code with the JAX library for scientific computing. | 378
google-deepmind/tf2jax | Converts TensorFlow functions to equivalent JAX Python functions. | 105
gordicaleksa/get-started-with-jax | Tutorials and resources for learning JAX, a popular alternative to PyTorch and TensorFlow for machine learning. | 661
jaxgaussianprocesses/gpjax | Provides a low-level interface to Gaussian process models in JAX for flexible extension and customisation. | 461
expectationmax/sklearn-jax-kernels | A set of composable kernels for scikit-learn, implemented in JAX to accelerate computation and gradient calculation. | 42
google/jaxopt | An open-source project providing hardware-accelerated, batchable, and differentiable optimizers in JAX for deep learning. | 933
information-fusion-lab-umass/nux | A library for building parametric models of complex distributions using normalizing flows and JAX. | 82
mpi4jax/mpi4jax | Enables communication of JAX arrays over MPI without copying data. | 445
google/neural-tangents | A high-level neural network API for defining and training complex hierarchical networks of finite or infinite width. | 2,278
hips/autograd | Automatically computes derivatives of Python and NumPy code for optimization tasks. | 7,017
chriswaites/jax-flows | A Python library implementing normalizing flows, a type of generative model used in machine learning. | 274
numpy/numpy | Provides support for large, multi-dimensional arrays and matrices, functions to manipulate them, and tools for integration with C/C++ code. | 28,087
differentiableuniverseinitiative/jax_cosmo | A Python library for differentiable cosmology using automatic differentiation on GPUs. | 178
patrick-kidger/diffrax | Provides numerical differential equation solvers that are autodifferentiable and GPU-capable, built on JAX. | 1,442
dipolar-quantum-gases/jaxfit | A package for GPU/TPU-accelerated nonlinear least-squares curve fitting using JAX. | 51