This is a minimal setup to observe grokking (delayed generalization) on an algorithmic task.
The task is modular addition. The model is a simple 2-layer MLP that takes in two learnable embeddings of dimension `hidden_dim=128` concatenated together. Each embedding represents an integer, and the target is their sum modulo 53.
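For illustration, here is a minimal sketch of that architecture, assuming PyTorch; the class name, the hidden widths inside the MLP, and the use of ReLU are assumptions for the sketch, not necessarily the exact code in `train.py`:

```python
import torch
import torch.nn as nn

P = 53            # modulus of the addition task
HIDDEN_DIM = 128  # embedding dimension, per the description above

class GrokkingMLP(nn.Module):
    """Hypothetical sketch of the 2-layer MLP described above."""

    def __init__(self, p=P, hidden_dim=HIDDEN_DIM):
        super().__init__()
        # One learnable embedding per integer in [0, p).
        self.embed = nn.Embedding(p, hidden_dim)
        # 2-layer MLP on the two concatenated embeddings.
        self.net = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, p),  # logits over the p residues
        )

    def forward(self, a, b):
        # a, b: integer tensors; concatenate their embeddings.
        x = torch.cat([self.embed(a), self.embed(b)], dim=-1)
        return self.net(x)
```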
Run `python train.py --grok` to see these training curves:

and `python train.py` for a more "normal" run:
The only difference between these two runs is the weight decay: it is set to 5 (!) in the non-grokking (comprehension) run and to 0.03 in the grokking run. For more details on the effect of hyperparameters on the grokking phenomenon, see the paper *Towards Understanding Grokking*.
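As a sketch of how that single knob could be wired up, assuming the optimizer is AdamW (an assumption; check `train.py` for the actual optimizer and learning rate):

```python
import torch

def make_optimizer(model, grok: bool):
    # Low weight decay (0.03) delays generalization (grokking);
    # high weight decay (5.0) makes the model generalize quickly.
    weight_decay = 0.03 if grok else 5.0
    # AdamW and lr=1e-3 are illustrative assumptions, not the repo's code.
    return torch.optim.AdamW(model.parameters(), lr=1e-3,
                             weight_decay=weight_decay)
```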
The `--log` option will log the training curves locally. `python plot.py` uses the logs in the `log/` folder to make the plots you see above.
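For example, to reproduce the grokking plot end to end (using only the flags described above):

```
python train.py --grok --log
python plot.py
```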
You'll need `torch`, `numpy`, `matplotlib`, `tqdm`, and potentially `sklearn` if you want to use the `--anim` option of `plot.py` to animate the embeddings.
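For reference, all of these can be installed from PyPI in one go (the `sklearn` import is provided by the `scikit-learn` package):

```
pip install torch numpy matplotlib tqdm scikit-learn
```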