flax

Neural Network Library

Provides a flexible neural network library and ecosystem for JAX

Flax is a neural network library for JAX that is designed for flexibility.

GitHub

6k stars
87 watching
652 forks
Language: Jupyter Notebook
last commit: about 1 month ago
Linked from 4 awesome lists

Topics: jax


Related projects:

Repository | Description | Stars
google-deepmind/dm-haiku | A JAX-based neural network library for building and optimizing neural networks. | 2,921
matthias-wright/flaxmodels | Provides pre-trained deep learning models for the JAX/Flax ecosystem. | 240
n2cholas/jax-resnet | Provides implementations and checkpoints for various ResNet variants using JAX and Flax. | 105
google/trax | An end-to-end deep learning library focused on clear code and speed. | 8,114
darshandeshpande/jax-models | Provides a collection of deep learning models and utilities in JAX/Flax for research purposes. | 151
google/neural-tangents | A high-level neural network API for defining and training complex hierarchical networks of finite or infinite width. | 2,291
google-deepmind/optax | A gradient processing and optimization library for machine learning research, providing building blocks for custom optimizers and gradient processing components. | 1,730
gordicaleksa/get-started-with-jax | Tutorials and resources for learning JAX, a popular alternative to PyTorch and TensorFlow for machine learning. | 670
google-deepmind/jraph | A lightweight library for working with graph neural networks in JAX. | 1,380
ikostrikov/jaxrl | Provides JAX implementations of various reinforcement learning algorithms with continuous action spaces. | 640
google/jaxopt | Hardware-accelerated, batchable, and differentiable optimizers in JAX. | 941
young-geng/easylm | A framework for training and serving large language models using JAX/Flax. | 2,428
google/dopamine | A research framework for fast prototyping of reinforcement learning algorithms. | 10,591
google/paxml | A framework for configuring and running machine learning experiments on top of JAX. | 461
patrick-kidger/equinox | A JAX-based library for building and running neural networks. | 2,157