optax
Gradient optimizer library
A gradient processing and optimization library designed to facilitate research and productivity in machine learning by providing building blocks for custom optimizers and gradient processing components.
Optax is a gradient processing and optimization library for JAX.
2k stars
35 watching
196 forks
Language: Python
Last commit: about 1 month ago
Linked from 2 awesome lists
Topics: machine-learning, optimization
Related projects:
Repository | Description | Stars |
---|---|---|
google/jaxopt | An open-source project providing hardware-accelerated, batchable and differentiable optimizers in JAX for deep learning. | 941 |
google-deepmind/jraph | A lightweight library for working with graph neural networks in JAX. | 1,380 |
google-deepmind/kfac-jax | Library providing an implementation of the K-FAC optimizer and curvature estimator for second-order optimization in neural networks. | 252 |
google-deepmind/dm_pix | An image processing library built on top of JAX to provide optimized and parallelized functions for machine learning research. | 395 |
google-deepmind/einshape | A unified reshaping library for JAX and other frameworks. | 100 |
matthias-wright/flaxmodels | Provides pre-trained deep learning models for the JAX/Flax ecosystem. | 240 |
google-deepmind/distrax | A library of probability distributions and bijectors with a focus on readability, extensibility, and compatibility with existing frameworks. | 538 |
deependersingla/deep_portfolio | An algorithm that optimizes portfolio allocation using reinforcement learning and supervised learning. | 168 |
google-research/sputnik | A library of optimized GPU kernels for sparse matrix operations used in deep learning. | 248 |
darshandeshpande/jax-models | Provides a collection of deep learning models and utilities in JAX/Flax for research purposes. | 151 |
100/solid | A comprehensive framework for solving optimization problems without gradient calculations. | 575 |
google-deepmind/chex | A set of utilities for writing reliable JAX code. | 797 |
locuslab/optnet | A PyTorch module that adds differentiable optimization as a layer to neural networks. | 517 |
google-deepmind/jaxline | Provides a Python-based framework for building distributed JAX training and evaluation experiments. | 153 |
neuralmagic/sparseml | Enables the creation of smaller neural network models through efficient pruning and quantization techniques. | 2,083 |