Skip to content

Commit 213b6a0

Browse files
committed
Use torchinfo
1 parent b17cce8 commit 213b6a0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pytorch/gpt_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import os
55
from utils import load_encoder_hparams_and_params
66
from model import GPT2
7-
import torchsummaryX
87
from tqdm import tqdm
8+
from torchinfo import summary
99

1010

1111
if __name__ == "__main__":
@@ -31,7 +31,7 @@
3131
print("input_ids:", input_ids)
3232

3333
model = GPT2(params, hparams, drop_p=0.1)
34-
torchsummaryX.summary(model, torch.ones(1, len(input_ids), dtype=torch.long))
34+
summary(model, input_size=(1, len(input_ids)), dtypes=[torch.long])
3535

3636
for _ in tqdm(range(args.n_tokens_to_generate), "generating"):
3737
logits = model(torch.tensor(input_ids).unsqueeze(0))

0 commit comments

Comments
 (0)