flax

Neural Network Library

Flax is a high-performance neural network library for JAX, designed for flexibility and ease of use.

GitHub

6k stars
86 watching
648 forks
Language: Jupyter Notebook
last commit: 2 days ago
Linked from 4 awesome lists

jax

Related projects:

Repository | Description | Stars
google-deepmind/dm-haiku | A JAX-based neural network library for building and optimizing neural networks. | 2,907
matthias-wright/flaxmodels | Provides pre-trained deep learning models for the JAX/Flax ecosystem. | 238
n2cholas/jax-resnet | Provides implementations and checkpoints for various ResNet variants using JAX and Flax. | 104
google/trax | An end-to-end deep learning library with clear code and speed. | 8,102
darshandeshpande/jax-models | Provides a collection of deep learning models and utilities in JAX/Flax for research purposes. | 151
google/neural-tangents | A high-level neural network API for defining and training complex hierarchical networks of finite or infinite width. | 2,278
google-deepmind/optax | A gradient processing and optimization library that provides building blocks for custom optimizers and gradient processing components. | 1,697
gordicaleksa/get-started-with-jax | Tutorials and resources for learning JAX, a popular alternative to PyTorch and TensorFlow for machine learning. | 661
google-deepmind/jraph | A lightweight library for working with graph neural networks in JAX. | 1,375
ikostrikov/jaxrl | Provides JAX implementations of various reinforcement learning algorithms with continuous action spaces. | 630
google/jaxopt | An open-source project providing hardware-accelerated, batchable, and differentiable optimizers in JAX for deep learning. | 933
young-geng/easylm | A framework for training and serving large language models using JAX/Flax. | 2,409
google/dopamine | A research framework for fast prototyping of reinforcement learning algorithms. | 10,569
google/paxml | A framework for configuring and running machine learning experiments on top of JAX. | 457
patrick-kidger/equinox | A JAX-based library for building and running neural networks with ease. | 2,118