🐛 Bug
As per the title, at commit 8b542cf the first iteration (iter 0) of the following command takes roughly twice as long:
python thunder/benchmarks/benchmark_litgpt.py --model_name Llama-2-7b-hf --compile thunder_inductor_cat --checkpoint_activations False --low_precision_mode none --micro_batch_size 1 --global_batch_size 64 --use_sdpa False --block_size 4096 --max_iters 10 --warmup_iters 5
compared to before #2502:
# Before
iter 0: loss 0.1650, iter time: 105627.25ms, t: 4096
# After
iter 0: loss 0.1650, iter time: 215099.65ms, t: 4096
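To put a number on "double", here is a minimal Python sketch that parses the two iter 0 lines above and computes the slowdown ratio. It only assumes the `iter time: <n>ms` field format visible in those two log lines; the helper name `iter_time_ms` is hypothetical and not part of `benchmark_litgpt.py`. For these logs the ratio comes out to about 2.04x (roughly 105.6 s before vs. 215.1 s after).

```python
import re

# iter 0 log lines copied verbatim from this report.
BEFORE = "iter 0: loss 0.1650, iter time: 105627.25ms, t: 4096"
AFTER = "iter 0: loss 0.1650, iter time: 215099.65ms, t: 4096"

# Assumed format: matches the "iter time: <float>ms" field seen in the lines above.
ITER_TIME_RE = re.compile(r"iter time:\s*([0-9.]+)ms")


def iter_time_ms(line: str) -> float:
    """Hypothetical helper: extract the iteration time (ms) from one log line."""
    match = ITER_TIME_RE.search(line)
    if match is None:
        raise ValueError(f"no 'iter time' field in: {line!r}")
    return float(match.group(1))


before_ms = iter_time_ms(BEFORE)
after_ms = iter_time_ms(AFTER)
slowdown = after_ms / before_ms

print(f"before: {before_ms / 1000:.1f} s, after: {after_ms / 1000:.1f} s, "
      f"slowdown: {slowdown:.2f}x")
```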
Tested with the latest container on an NVIDIA B200.
cc @rdspring1 @kshitij12345 @crcrpar