jax

Numerical computing library

A Python library for high-performance numerical computing and machine learning.

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
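The following is a minimal sketch of how the three transformations named in the tagline compose. It assumes a recent JAX release that provides `jax.grad`, `jax.vmap`, and `jax.jit`. The `loss` function and the toy data are hypothetical illustrations and are not taken from the JAX documentation.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model for illustration only:
# squared-error loss of a linear model, written in NumPy style.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# Differentiate: gradient of `loss` with respect to its first argument, w.
grad_loss = jax.grad(loss)

# Vectorize: map the gradient over the leading axis of x and y.
# w is shared across examples, so its in_axes entry is None.
per_example_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# JIT: compile the composed function with XLA (runs on CPU, GPU, or TPU).
fast_per_example_grads = jax.jit(per_example_grads)

# Hypothetical toy data: 3 examples with 2 features each.
w = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, 1.0, 2.0])

g = fast_per_example_grads(w, x, y)
print(g.shape)  # (3, 2): one gradient per example
```

Each transformation returns an ordinary Python function, so they can be nested in any order. In this sketch, `vmap` applied to `grad` yields per-example gradients. Because the full-batch loss is the mean of the per-example losses, the mean of these gradients equals `grad_loss(w, x, y)`.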

GitHub

31k stars
336 watching
3k forks
Language: Python
Last commit: about 1 month ago
Linked from 6 awesome lists



Related projects:

Repository | Description | Stars
dfm/extending-jax | Provides infrastructure for interfacing custom C++ and CUDA code with the JAX library for scientific computing. | 379
google-deepmind/tf2jax | Converts TensorFlow functions to equivalent JAX Python functions. | 109
gordicaleksa/get-started-with-jax | A repository providing tutorials and resources to learn JAX, a popular alternative to PyTorch and TensorFlow for machine learning. | 670
jaxgaussianprocesses/gpjax | Provides a low-level interface to Gaussian process models in JAX for flexible extension and customisation. | 467
expectationmax/sklearn-jax-kernels | A set of composable kernels for scikit-learn implemented in JAX to accelerate computation and gradient calculation. | 42
google/jaxopt | An open-source project providing hardware-accelerated, batchable, and differentiable optimizers in JAX for deep learning. | 941
information-fusion-lab-umass/nux | A library for building parametric models of complex distributions using normalizing flows and JAX. | 82
mpi4jax/mpi4jax | Enables JAX array communication over MPI without copying data. | 453
google/neural-tangents | A high-level neural network API for defining and training complex hierarchical networks of finite or infinite width. | 2,291
hips/autograd | Automatically computes derivatives of Python and NumPy code for optimization tasks. | 7,049
chriswaites/jax-flows | A Python library implementing normalizing flows, a type of generative model used in machine learning. | 275
numpy/numpy | A comprehensive library providing efficient numerical computation and data manipulation capabilities for Python-based scientific computing. | 28,350
differentiableuniverseinitiative/jax_cosmo | A Python library for differentiable cosmology using automatic differentiation on GPUs. | 182
patrick-kidger/diffrax | Numerical differential equation solvers in JAX that are autodifferentiable and GPU-capable. | 1,480
dipolar-quantum-gases/jaxfit | A package for GPU/TPU-accelerated nonlinear least-squares curve fitting using JAX. | 53