get-started-with-JAX

JAX tutorial series

A repository providing tutorials and resources to learn JAX, a popular alternative to PyTorch and TensorFlow for machine learning.

The purpose of this repo is to make it easy to get started with JAX, Flax, and Haiku. It contains my "Machine Learning with JAX" series of tutorials (YouTube videos and Jupyter Notebooks) as well as the content I found useful while learning about the JAX ecosystem.
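The listing does not reproduce the notebooks themselves. As a rough taste of what "getting started with JAX" involves, the sketch below is illustrative only and is not taken from the tutorials. It fits a toy linear model using three core JAX transformations: `jax.grad` for automatic differentiation, `jax.jit` for XLA compilation, and `jax.vmap` for automatic batching. The model, the data (y = 2x + 1), the learning rate and the step count are all assumptions chosen for the example.

```python
import jax
import jax.numpy as jnp

# Toy linear model (illustrative assumption): prediction = xi @ w + b
def predict(params, xi):
    w, b = params
    return xi @ w + b

# Squared error for a single example
def example_loss(params, xi, yi):
    return (predict(params, xi) - yi) ** 2

# jax.vmap maps example_loss over axis 0 of xi and yi; params are shared (None)
batched_loss = jax.vmap(example_loss, in_axes=(None, 0, 0))

def loss(params, x, y):
    return jnp.mean(batched_loss(params, x, y))

# jax.grad differentiates w.r.t. the first argument (params);
# jax.jit compiles the gradient function with XLA on first call
grad_loss = jax.jit(jax.grad(loss))

def train(params, x, y, lr=0.1, steps=500):
    for _ in range(steps):
        grads = grad_loss(params, x, y)
        # Plain gradient descent applied leaf-by-leaf over the (w, b) tuple
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params

# Synthetic data drawn from y = 2x + 1
x = jnp.array([[0.0], [1.0], [2.0], [3.0]])
y = 2.0 * x[:, 0] + 1.0

params = train((jnp.zeros((1,)), jnp.array(0.0)), x, y)
```

After training, `params` should be close to `w = [2.0]` and `b = 1.0`. Libraries such as Optax (listed among the tags) would normally replace the hand-written update step.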


661 stars
9 watching
98 forks
Language: Jupyter Notebook
last commit: 12 months ago
Linked from 1 awesome list

Topics: deep-learning, flax, haiku, jax, jupyter, lax, learn-jax, machine-learning, numpy, optax, python, tutorial, xla

Related projects:

| Repository | Description | Stars |
| --- | --- | --- |
| jaxgaussianprocesses/gpjax | Provides a low-level interface to Gaussian process models in JAX for flexible extension and customisation | 461 |
| ikostrikov/jaxrl | Provides JAX implementations of various reinforcement learning algorithms with continuous action spaces | 630 |
| darshandeshpande/jax-models | Provides a collection of deep learning models and utilities in JAX/Flax for research purposes | 151 |
| google-deepmind/chex | A set of utilities for writing reliable JAX code | 788 |
| matthias-wright/flaxmodels | Provides pre-trained deep learning models for the JAX/Flax ecosystem | 238 |
| dfm/extending-jax | Provides infrastructure to interface custom C++ and CUDA code with the JAX library for scientific computing | 378 |
| hakky54/java-tutorials | A repository containing various Java tutorials on topics such as Elasticsearch, gRPC, security, and serialization | 36 |
| expectationmax/sklearn-jax-kernels | A set of composable kernels for scikit-learn implemented in JAX to accelerate computation and gradient calculation | 42 |
| jeremiecoullon/jax-tqdm | A library that adds progress-bar support to JAX scans and loops | 96 |
| google-deepmind/jraph | A lightweight library for working with graph neural networks in JAX | 1,375 |
| vict0rsch/deep_learning | A collection of tutorials and resources on implementing deep learning models using Python libraries such as Keras and Lasagne | 426 |
| logicalclocks/hopsworks-tutorials | A collection of tutorials and notebooks providing hands-on experience with the Hopsworks Platform for machine learning development and data analysis | 251 |
| nvidia/jax-toolbox | A collection of optimized JAX libraries and examples for simplified development on NVIDIA GPUs | 245 |
| owickstrom/gi-gtk-declarative | A Haskell package providing a declarative programming framework for building GTK+ applications | 288 |
| kmheckel/spyx | A JAX-based library for training and utilizing spiking neural networks | 101 |