Flax (JAX) implementation of DeepSeek-R1-Distill-Qwen-1.5B with weights ported from Hugging Face.

J-Rosser-UK/Torch2Jax-DeepSeek-R1-Distill-Qwen-1.5B

Torch2Jax-DeepSeek-R1-Distill-Qwen-1.5B

Flax (JAX) implementation of DeepSeek-R1-Distill-Qwen-1.5B with weights ported from Hugging Face.

Overview

Colab: https://colab.research.google.com/drive/1jJaAARwbsFeV5hZoffrNwFhc8i2I-7ji?usp=sharing

This repository provides both Flax (JAX) and PyTorch implementations of the DeepSeek-R1-Distill-Qwen-1.5B model. It includes:

  • Inference [QUICKSTART]:

    • inference.ipynb: A quickstart notebook that downloads the PyTorch parameters, converts them to Flax, loads the model, and generates text.
  • Flax Implementation:

    • model_flax.py: The Flax implementation.
  • PyTorch Implementation:

    • model_torch.py: A reference implementation in PyTorch.
  • Conversion Script:

    • torch_to_flax.py: A utility to convert a PyTorch checkpoint (state dictionary) into Flax parameters.
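The kind of conversion torch_to_flax.py performs can be illustrated with a minimal sketch. This is not the repository's code: the function name `torch_state_to_flax`, the `embed_names` parameter, and the example key names are all hypothetical, and the sketch only covers dense and embedding weights. It relies on two known layout facts: `torch.nn.Linear` stores its weight as (out_features, in_features) while `flax.linen.Dense` expects a `kernel` of shape (in_features, out_features), and `flax.linen.Embed` names its table `embedding` with the same shape PyTorch uses.

```python
import numpy as np


def torch_state_to_flax(state_dict, embed_names=("embed_tokens",)):
    """Turn a flat {"a.b.weight": array} mapping into nested Flax-style params.

    Hypothetical helper for illustration only. Values are anything
    np.asarray accepts; with real PyTorch tensors you would pass
    tensor.detach().cpu().numpy().
    """
    params = {}
    for name, value in state_dict.items():
        array = np.asarray(value)
        *scope, leaf = name.split(".")
        if leaf == "weight" and scope and scope[-1] in embed_names:
            # Embedding tables keep their (vocab, features) layout.
            leaf = "embedding"
        elif leaf == "weight" and array.ndim == 2:
            # Linear weight (out, in) becomes Dense kernel (in, out).
            leaf, array = "kernel", array.T
        # Walk (and create) the nested dictionaries for the module path.
        node = params
        for key in scope:
            node = node.setdefault(key, {})
        node[leaf] = array
    return params


# Toy state dictionary with made-up key names and tiny shapes.
state = {
    "embed_tokens.weight": np.zeros((10, 4), dtype=np.float32),
    "layers.0.q_proj.weight": np.arange(12, dtype=np.float32).reshape(3, 4),
    "layers.0.q_proj.bias": np.zeros(3, dtype=np.float32),
}
params = torch_state_to_flax(state)
```

Any other parameter, such as a 1-D normalization weight, passes through unchanged in this sketch; a real converter would also need to rename those to match the Flax module definitions.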

System Requirements

Single GPU

16 GB of GPU VRAM plus 64 GB of system RAM (the RAM can include swap space).

Multi-Device

Runs sharded on a v2-8 TPU in Google Colab.
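As a rough illustration of what sharding a weight across devices looks like in JAX, the sketch below splits a 2-D matrix column-wise over every available device and uses it in a jitted matmul. It is an assumption about the general approach, not the repository's sharding code: the helper name `shard_columns` and the mesh axis name `"model"` are hypothetical. It also runs on a single CPU or GPU, where the mesh simply has one device.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def shard_columns(weight, axis_name="model"):
    """Place a 2-D weight on all devices, split along its last axis."""
    devices = np.array(jax.devices())
    # 1-D mesh over every visible device (8 cores on a v2-8 TPU).
    mesh = Mesh(devices, (axis_name,))
    # Rows are replicated (None); columns are divided across the mesh axis.
    sharding = NamedSharding(mesh, P(None, axis_name))
    return jax.device_put(weight, sharding)


n_dev = jax.device_count()
# The column count is a multiple of the device count so it divides evenly.
w = jnp.arange(4 * 8 * n_dev, dtype=jnp.float32).reshape(4, 8 * n_dev)
w_sharded = shard_columns(w)

x = jnp.ones((2, 4), dtype=jnp.float32)
# jit compiles the matmul against the sharded operand's layout.
y = jax.jit(lambda a, b: a @ b)(x, w_sharded)
```

Each device holds only its slice of the columns, which is what lets a model that is too large for one device fit across several.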
