jax
Numerical computing library
A Python library for high-performance numerical computing and machine learning.
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
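The tagline's "composable transformations" can be illustrated with a minimal sketch that stacks JAX's public `jax.grad`, `jax.vmap`, and `jax.jit` functions on a single function. The linear model, the `loss` function, and the toy data are illustrative assumptions, not taken from the project itself. The example computes per-example gradients for a batch and JIT-compiles the result with XLA.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example loss: squared error of a linear model.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# Differentiate: gradient of `loss` with respect to its first argument, w.
grad_loss = jax.grad(loss)

# Vectorize: map the per-example gradient over a batch of (x, y) pairs,
# sharing the same weights w across the batch (in_axes=None for w).
batched_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# JIT: compile the composed function with XLA for whichever backend
# (CPU, GPU, or TPU) is available.
fast_batched_grads = jax.jit(batched_grads)

# Toy data (illustrative only).
w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

g = fast_batched_grads(w, xs, ys)
print(g.shape)  # (3, 2): one gradient vector per example
```

Because each transformation takes a function and returns a function, the order can be rearranged (for example, `jax.vmap(jax.jit(grad_loss), ...)`) without changing the numerical result.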
31k stars
336 watching
3k forks
Language: Python
Last commit: 11 months ago
Linked from 6 awesome lists
Related projects:
| Repository | Description | Stars |
|---|---|---|
| | Provides infrastructure for interfacing custom C++ and CUDA code with JAX | 379 |
| | Converts TensorFlow functions to equivalent JAX Python functions. | 109 |
| | A repository providing tutorials and resources to learn JAX, a popular alternative to PyTorch and TensorFlow for machine learning. | 670 |
| | Provides a low-level interface to Gaussian process models in JAX for flexible extension and customisation | 467 |
| | A set of composable kernels for scikit-learn implemented in JAX to accelerate computation and gradient calculation. | 42 |
| | An open-source project providing hardware-accelerated, batchable and differentiable optimizers in JAX for deep learning. | 941 |
| | A library for building parametric models of complex distributions using normalizing flows and JAX. | 82 |
| | Enables JAX array communication over MPI without copying data | 453 |
| | A high-level neural network API for defining and training complex hierarchical networks of finite or infinite width | 2,291 |
| | Automatically computes derivatives of Python and NumPy code for optimization tasks | 7,049 |
| | A Python library implementing normalizing flows, a type of generative model used in machine learning. | 275 |
| | A comprehensive library providing efficient numerical computation and data manipulation capabilities for Python-based scientific computing. | 28,350 |
| | A Python library for differentiable cosmology using automatic differentiation on GPUs | 182 |
| | Provides numerical differential equation solvers in JAX that are autodifferentiable and GPU-capable. | 1,480 |
| | A package for GPU/TPU-accelerated nonlinear least-squares curve fitting using JAX | 53 |