jax
Numerical computing library
A library that provides high-performance numerical computing and machine learning capabilities.
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
31k stars
336 watching
3k forks
Language: Python
last commit: about 1 month ago
Linked from 6 awesome lists
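The tagline above names JAX's composable transformations: differentiate, vectorize, and JIT-compile. The following is a minimal sketch of how they stack, assuming JAX is installed (for example with `pip install jax`). It uses the public `jax.grad`, `jax.vmap`, and `jax.jit` functions. The `loss` function and the toy data are hypothetical and for illustration only; they are not taken from the project.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: squared error of a linear model for a single example
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# Differentiate: gradient of loss with respect to its first argument (w)
grad_loss = jax.grad(loss)

# Vectorize: map the per-example gradient over axis 0 of x and y,
# while sharing the same w across the batch (in_axes=None for w)
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# JIT-compile the composed function with XLA (CPU, GPU, or TPU backend)
fast_batched_grad = jax.jit(batched_grad)

# Toy data (hypothetical)
w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

# One gradient row per example; analytically 2 * (x . w - y) * x
grads = fast_batched_grad(w, xs, ys)
print(grads)
```

Because each transformation takes a function and returns a function, they can be applied in any order. For example, `jax.jit(jax.vmap(jax.grad(loss), ...))` is an ordinary Python callable that can be transformed again.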
Related projects:
Repository | Description | Stars |
---|---|---|
dfm/extending-jax | This project provides infrastructure to interface custom C++ and CUDA code with the JAX library for scientific computing | 379 |
google-deepmind/tf2jax | Converts TensorFlow functions to equivalent JAX Python functions. | 109 |
gordicaleksa/get-started-with-jax | A repository providing tutorials and resources to learn JAX, a popular alternative to PyTorch and TensorFlow for machine learning. | 670 |
jaxgaussianprocesses/gpjax | Provides a low-level interface to Gaussian process models in JAX for flexible extension and customisation | 467 |
expectationmax/sklearn-jax-kernels | A set of composable kernels for scikit-learn implemented in JAX to accelerate computation and gradient calculation. | 42 |
google/jaxopt | An open-source project providing hardware-accelerated, batchable, and differentiable optimizers in JAX for deep learning. | 941 |
information-fusion-lab-umass/nux | A library for building parametric models of complex distributions using normalizing flows and JAX. | 82 |
mpi4jax/mpi4jax | Enables JAX array communication over MPI without copying data | 453 |
google/neural-tangents | A high-level neural network API for defining and training complex hierarchical networks of finite or infinite width | 2,291 |
hips/autograd | Automatically computes derivatives of Python and NumPy code for optimization tasks | 7,049 |
chriswaites/jax-flows | A Python library implementing normalizing flows, a type of generative model used in machine learning. | 275 |
numpy/numpy | A comprehensive library providing efficient numerical computation and data manipulation capabilities for Python-based scientific computing. | 28,350 |
differentiableuniverseinitiative/jax_cosmo | A Python library for differentiable cosmology using automatic differentiation on GPUs | 182 |
patrick-kidger/diffrax | Provides numerical differential equation solvers using autodifferentiable and GPU-capable JAX. | 1,480 |
dipolar-quantum-gases/jaxfit | A package for GPU/TPU-accelerated nonlinear least-squares curve fitting using JAX | 53 |