jax
Numerical computing library
A library that provides high-performance numerical computing and machine learning capabilities.
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
31k stars
336 watching
3k forks
Language: Python
last commit: 3 months ago
Linked from 6 awesome lists
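The tagline describes "composable transformations" (differentiate, vectorize, JIT). Below is a minimal sketch of what that looks like in practice. It assumes a recent `jax` release is installed and runs on CPU when no GPU/TPU is present. The `loss` function and the data are illustrative assumptions, not taken from the project. The sketch composes `jax.grad`, `jax.vmap` and `jax.jit` to get compiled per-example gradients.

```python
# Sketch: composing JAX transformations (grad -> vmap -> jit).
# Assumes `jax` is installed; the model and data below are illustrative only.
import jax
import jax.numpy as jnp


def loss(w, x, y):
    """Squared error of a 1-D linear model for a single (x, y) example."""
    pred = w[0] * x + w[1]
    return (pred - y) ** 2


# grad: derivative of `loss` with respect to its first argument, `w`.
grad_loss = jax.grad(loss)

# vmap: apply the per-example gradient across a batch of (x, y) pairs,
# sharing the same `w` for every example (in_axes=None for `w`).
batched_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit: compile the composed function with XLA for the available backend.
fast_batched_grads = jax.jit(batched_grads)

w = jnp.array([2.0, 1.0])
xs = jnp.array([0.0, 1.0, 2.0])
ys = jnp.array([1.0, 3.0, 4.0])

g = fast_batched_grads(w, xs, ys)
print(g.shape)  # (3, 2): one 2-element gradient per example
```

Each transformation takes a function and returns a new function, so they nest in any order. Wrapping `jit` around the outside compiles the whole vmapped gradient as a single program instead of compiling each piece separately.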
Related projects:
| Repository | Description | Stars |
|---|---|---|
|  | This project provides infrastructure to interface custom C++ and CUDA code with the JAX library for scientific computing | 379 |
|  | Converts TensorFlow functions to equivalent JAX Python functions. | 109 |
|  | A repository providing tutorials and resources to learn JAX, a popular alternative to PyTorch and TensorFlow for machine learning. | 670 |
|  | Provides a low-level interface to Gaussian process models in JAX for flexible extension and customisation | 467 |
|  | A set of composable kernels for scikit-learn implemented in JAX to accelerate computation and gradient calculation. | 42 |
|  | An open-source project providing hardware accelerated, batchable and differentiable optimizers in JAX for deep learning. | 941 |
|  | A library for building parametric models of complex distributions using normalizing flows and JAX. | 82 |
|  | Enables JAX array communication over MPI without copying data | 453 |
|  | A high-level neural network API for defining and training complex hierarchical networks of finite or infinite width | 2,291 |
|  | Automatically computes derivatives of Python and NumPy code for optimization tasks | 7,049 |
|  | A Python library implementing normalizing flows, a type of generative model used in machine learning. | 275 |
|  | A comprehensive library providing efficient numerical computation and data manipulation capabilities for Python-based scientific computing. | 28,350 |
|  | A Python library for differentiable cosmology using automatic differentiation on GPUs | 182 |
|  | Provides numerical differential equation solvers using autodifferentiable and GPU-capable JAX. | 1,480 |
|  | A package for GPU/TPU accelerated nonlinear least-squares curve fitting using JAX | 53 |