FlaxGPT is a minimal Flax implementation of a GPT (decoder-only transformer) model. The entire code fits in a single notebook, which makes it well suited for hacking and educational purposes.
Open the main flax_gpt Colab and start hacking.
- Implement the GPT model using Flax (a minimal decoder-block sketch follows this list)
- Load / convert the LLaMA2-7B checkpoint for prediction
- Implement a KV cache for prediction (see the KV-cache sketch after this list)
- Pretraining (example)
- Finetuning (example)
- LoRA finetuning
- Quantization
- Distributed training (TPUs)
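To give a flavor of what the notebook covers, here is a minimal sketch of a causal decoder block in `flax.linen`. The module and hyperparameter names (`DecoderBlock`, `d_model`, `n_heads`) are illustrative assumptions, not necessarily the notebook's actual API:

```python
# Minimal sketch of a pre-norm causal decoder block in Flax.
# Names and defaults here are hypothetical, not FlaxGPT's exact API.
import jax
import jax.numpy as jnp
import flax.linen as nn


class DecoderBlock(nn.Module):
    d_model: int = 256   # embedding width (illustrative default)
    n_heads: int = 4     # attention heads (illustrative default)

    @nn.compact
    def __call__(self, x):
        # x: (batch, seq_len, d_model)
        seq_len = x.shape[1]
        # Causal mask: position i may only attend to positions <= i.
        mask = nn.make_causal_mask(jnp.ones((x.shape[0], seq_len)))
        # Pre-norm self-attention with a residual connection.
        h = nn.LayerNorm()(x)
        h = nn.MultiHeadDotProductAttention(num_heads=self.n_heads)(h, h, mask=mask)
        x = x + h
        # Pre-norm MLP with a residual connection.
        h = nn.LayerNorm()(x)
        h = nn.Dense(4 * self.d_model)(h)
        h = nn.gelu(h)
        h = nn.Dense(self.d_model)(h)
        return x + h


# Usage: initialize parameters and run a forward pass.
block = DecoderBlock()
x = jnp.zeros((1, 16, 256))
params = block.init(jax.random.PRNGKey(0), x)
y = block.apply(params, x)  # (1, 16, 256)
```

A full GPT stacks several such blocks between a token/position embedding and a final projection to vocabulary logits.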
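The KV-cache item above refers to the standard trick for fast autoregressive decoding: keys and values of already-generated tokens are stored so each step only computes attention for the newest token. Below is a minimal functional sketch with preallocated buffers; the helper names (`init_cache`, `update_cache`) are hypothetical, not the notebook's actual API:

```python
# Minimal sketch of a KV cache with fixed-size preallocated buffers.
# Helper names are hypothetical, not FlaxGPT's exact API.
import jax.numpy as jnp


def init_cache(batch, max_len, n_heads, head_dim):
    # Preallocated key/value buffers plus the next write position.
    shape = (batch, max_len, n_heads, head_dim)
    return {"k": jnp.zeros(shape), "v": jnp.zeros(shape), "pos": 0}


def update_cache(cache, k_new, v_new):
    # k_new, v_new: (batch, n_heads, head_dim) for the newest token.
    pos = cache["pos"]
    return {
        "k": cache["k"].at[:, pos].set(k_new),
        "v": cache["v"].at[:, pos].set(v_new),
        "pos": pos + 1,
    }
```

At decode step `pos`, attention for the new token is computed against `cache["k"][:, :pos + 1]` and `cache["v"][:, :pos + 1]` only. Under `jit`, where `pos` is traced, one would typically keep the full buffers and mask out unwritten positions instead of slicing.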
Here are some tutorials on how I implemented GPT from scratch.
- GPT From Scratch Using Flax explains how I created FlaxGPT, step by step.
- GPT From Scratch Using Jax: check this one out if you prefer a more "hardcore" implementation that uses only low-level JAX.