Colab: https://colab.research.google.com/drive/1jJaAARwbsFeV5hZoffrNwFhc8i2I-7ji?usp=sharing
This repository provides both Flax (JAX) and PyTorch implementations of the DeepSeek-R1-Distill-Qwen-1.5B model. It includes:
- Inference [QUICKSTART]: inference.ipynb, a quickstart notebook that downloads the PyTorch parameters, converts them to Flax, loads the model, and runs text generation.
- Flax Implementation: model_flax.py, the Flax implementation.
- PyTorch Implementation: model_torch.py, a reference implementation in PyTorch.
- Conversion Script: torch_to_flax.py, a utility that converts a PyTorch checkpoint (state dictionary) into Flax parameters.
Hardware requirements: 16 GB of GPU VRAM plus 64 GB of system RAM (swap space can stand in for part of the RAM).
The model also runs sharded on a v2-8 TPU on Google Colab.
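For a rough sense of scale, the memory needed for the weights alone is the parameter count times the bytes per parameter. The sketch below uses the nominal 1.5B from the model name as an assumption, since the exact parameter count may differ. It covers weights only and ignores activations, the KV cache during generation, and any temporary copies made during conversion. The result is therefore a lower bound, not an explanation of the requirements above. `weight_memory_gb` is a hypothetical helper.

```python
# Bytes used by one parameter in each floating-point format.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2}


def weight_memory_gb(num_params, dtype):
    """Memory for the weights alone, in decimal gigabytes (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


if __name__ == "__main__":
    # Nominal 1.5B parameters, taken from the model name (an assumption).
    for dtype in ("float32", "bfloat16"):
        print(dtype, weight_memory_gb(1.5e9, dtype))  # float32 6.0, bfloat16 3.0
```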