get-started-with-JAX

JAX tutorial series

A repository providing tutorials and resources to learn JAX, a popular alternative to PyTorch and TensorFlow for machine learning.

The purpose of this repo is to make it easy to get started with JAX, Flax, and Haiku. It contains my "Machine Learning with JAX" series of tutorials (YouTube videos and Jupyter Notebooks) as well as the content I found useful while learning about the JAX ecosystem.
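For readers new to JAX, here is a minimal sketch of two of JAX's core function transformations. It is not taken from the repository's notebooks. `jax.grad` turns a Python function into one that computes its gradient, and `jax.jit` compiles a function with XLA. The linear-model loss, the data, and the learning rate below are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Hypothetical example: mean squared error of a linear model y ~ x @ w
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# jax.grad builds a function returning d(loss)/d(w), the gradient w.r.t. the first argument
grad_loss = jax.grad(loss)
# jax.jit compiles that gradient function with XLA on first call
fast_grad = jax.jit(grad_loss)

# Hypothetical toy data
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = fast_grad(w, x, y)

lr = 0.1  # hypothetical learning rate
w_new = w - lr * g  # one plain gradient-descent step
print(g, w_new)
```

With `w` at zero, the gradient works out by hand to 2/n * x^T (x w - y) = [-7, -10] for n = 2. A single step with learning rate 0.1 therefore moves `w` to [0.7, 1.0].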


670 stars
9 watching
97 forks
Language: Jupyter Notebook
last commit: about 1 year ago
Linked from 1 awesome list

deep-learning, flax, haiku, jax, jupyter, lax, learn-jax, machine-learning, numpy, optax, python, tutorial, xla


Related projects:

Repository | Description | Stars
jaxgaussianprocesses/gpjax | Provides a low-level interface to Gaussian process models in JAX for flexible extension and customisation | 467
ikostrikov/jaxrl | Provides JAX implementations of various reinforcement learning algorithms with continuous action spaces | 640
darshandeshpande/jax-models | Provides a collection of deep learning models and utilities in JAX/Flax for research purposes | 151
google-deepmind/chex | A set of utilities for writing reliable JAX code | 797
matthias-wright/flaxmodels | Provides pre-trained deep learning models for the JAX/Flax ecosystem | 240
dfm/extending-jax | Provides infrastructure to interface custom C++ and CUDA code with the JAX library for scientific computing | 379
hakky54/java-tutorials | A repository containing various Java tutorials on topics such as Elasticsearch, gRPC, security, and serialization | 35
expectationmax/sklearn-jax-kernels | A set of composable kernels for scikit-learn, implemented in JAX to accelerate computation and gradient calculation | 42
jeremiecoullon/jax-tqdm | A library that adds dynamic progress bars to JAX scans and loops | 100
google-deepmind/jraph | A lightweight library for working with graph neural networks in JAX | 1,380
vict0rsch/deep_learning | A collection of tutorials and resources on implementing deep learning models using Python libraries such as Keras and Lasagne | 426
logicalclocks/hopsworks-tutorials | A collection of tutorials and notebooks providing hands-on experience with the Hopsworks Platform for machine learning development and data analysis | 261
nvidia/jax-toolbox | Provides optimized tools and infrastructure for JAX development on NVIDIA GPUs | 268
owickstrom/gi-gtk-declarative | A Haskell package providing a declarative programming framework for building GTK+ applications | 288
kmheckel/spyx | A JAX-based library for training and using spiking neural networks | 104