jax
Acceleration tool
Accelerates numerical computing by automatically differentiating and compiling Python functions for high-performance execution on GPUs and TPUs.
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
30k stars
337 watching
3k forks
Language: Python
last commit: 6 days ago
Linked from 6 awesome lists
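A minimal sketch of how the transformations named in the description compose (differentiate, vectorize, JIT), assuming a recent JAX install. The toy linear-regression loss and the names `loss`, `w`, `xs`, `ys` are illustrative only, not taken from the project's docs. It runs on CPU when no GPU or TPU is present.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error of a toy linear model: prediction is the dot product x . w
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# Differentiate: gradient of loss with respect to its first argument, w
grad_loss = jax.grad(loss)

# Vectorize: map grad_loss over a batch of (x, y) pairs while sharing one w
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# JIT-compile the composed function with XLA for the available backend
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

# One per-example gradient row for each (x, y) pair
grads = fast_batched_grad(w, xs, ys)
print(grads)
```

Each transformation takes a Python function and returns another one, so they nest freely: here `vmap` wraps the output of `grad`, and `jit` compiles the whole composition in one step.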
Related projects:
Repository | Description | Stars |
---|---|---|
dfm/extending-jax | Provides infrastructure for interfacing custom C++ and CUDA code with the JAX library for scientific computing | 378 |
google-deepmind/tf2jax | Converts TensorFlow functions to equivalent JAX Python functions. | 105 |
gordicaleksa/get-started-with-jax | A repository providing tutorials and resources to learn JAX, a popular alternative to PyTorch and TensorFlow for machine learning. | 661 |
jaxgaussianprocesses/gpjax | Provides a low-level interface to Gaussian process models in JAX for flexible extension and customisation | 461 |
expectationmax/sklearn-jax-kernels | A set of composable kernels for scikit-learn implemented in JAX to accelerate computation and gradient calculation. | 42 |
google/jaxopt | An open-source project providing hardware-accelerated, batchable, and differentiable optimizers in JAX for deep learning. | 933 |
information-fusion-lab-umass/nux | A library for building parametric models of complex distributions using normalizing flows and JAX. | 82 |
mpi4jax/mpi4jax | Enables JAX array communication over MPI without copying data | 445 |
google/neural-tangents | A high-level neural network API for defining and training complex hierarchical networks of finite or infinite width | 2,278 |
hips/autograd | Automatically computes derivatives of Python and NumPy code for optimization tasks | 7,017 |
chriswaites/jax-flows | A Python library implementing normalizing flows, a type of generative model used in machine learning. | 274 |
numpy/numpy | Provides support for large, multi-dimensional arrays and matrices, along with functions to manipulate them, as well as tools for integration with C/C++ code. | 28,087 |
differentiableuniverseinitiative/jax_cosmo | A Python library for differentiable cosmology using automatic differentiation on GPUs | 178 |
patrick-kidger/diffrax | Provides numerical differential equation solvers using autodifferentiable and GPU-capable JAX. | 1,442 |
dipolar-quantum-gases/jaxfit | A package for GPU/TPU accelerated nonlinear least-squares curve fitting using JAX | 51 |