jax-models

Deep Learning Models

Provides a collection of deep learning models and utilities in JAX/Flax for research purposes.

Unofficial JAX implementations of deep learning research papers.


151 stars
5 watching
8 forks
Language: Python
Last commit: over 2 years ago
Linked from 1 awesome list

Topics: artificial-intelligence, computer-vision, convolutional-neural-networks, deep-learning, flax, jax, machine-learning, research-paper-implementation, transformers


Related projects:

Repository | Description | Stars
matthias-wright/flaxmodels | Provides pre-trained deep learning models for the JAX/Flax ecosystem. | 238
google-deepmind/jraph | A lightweight library for working with graph neural networks in JAX. | 1,375
yuyang-huang/keras-inception-resnet-v2 | An implementation of the Inception-ResNet v2 deep learning model in Keras. | 180
vict0rsch/deep_learning | A collection of tutorials and resources on implementing deep learning models with Python libraries such as Keras and Lasagne. | 426
deepseek-ai/deepseek-vl | A multimodal AI model for real-world vision-language understanding applications. | 2,077
kuleshov/deep-learning-models | Implementations of various deep learning algorithms in Python using Theano and Lasagne. | 24
n2cholas/jax-resnet | Provides implementations and checkpoints for various ResNet variants using JAX and Flax. | 104
ikostrikov/jaxrl | Provides JAX implementations of various reinforcement learning algorithms with continuous action spaces. | 630
google/jaxopt | An open-source project providing hardware-accelerated, batchable and differentiable optimizers in JAX for deep learning. | 933
nitishsrivastava/deepnet | A collection of GPU-accelerated deep learning algorithms implemented in Python. | 895
scicloj/scicloj.ml.clj-djl | Provides pre-trained machine learning models for natural language processing tasks using Clojure and the clj-djl framework. | 0
google-deepmind/jaxline | Provides a Python-based framework for building distributed JAX training and evaluation experiments. | 152
balavenkatesh3322/nlp-pretrained-model | A collection of pre-trained natural language processing models. | 170
jaxgaussianprocesses/gpjax | Provides a low-level interface to Gaussian process models in JAX for flexible extension and customisation. | 461
l0sg/relational-rnn-pytorch | An implementation of DeepMind's Relational Recurrent Neural Networks (Santoro et al. 2018) in PyTorch for word language modeling. | 244