Commit 25a9ca6

lightning example
1 parent 857a9b0 commit 25a9ca6

File tree

2 files changed (+43, -31 lines)


docs/source/examples/lightning.md

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+# Pytorch Lightning
+
+```{eval-rst}
+.. literalinclude:: ./lightning.py
+```
docs/source/examples/lightning.py

Lines changed: 38 additions & 31 deletions
@@ -1,10 +1,10 @@
-from dataclasses import dataclass
+import os
 from pathlib import Path
 
-import deepspeed
+import lightning as L
 import torch
-
 from datasets import load_dataset
+
 from torch import nn
 from torch.utils.data import Dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -42,47 +42,54 @@ def __getitem__(self, idx):
         }
 
 
-@dataclass
-class DSPArgs:
-    deepspeed_config: str
-    # train_batch_size: int
-    # batch_size: int
+class GPT2LightningWrapper(L.LightningModule):
+    def __init__(self):
+        super().__init__()
+        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
+
+    def training_step(self, batch, batch_idx):
+        device_batch = {k: v.to(self.model.device) for k, v in batch.items()}
+        loss = self.model(**device_batch).loss
+        self.log("train_loss", loss)
+        return loss
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
+        return optimizer
 
 
 def train():
-    model = AutoModelForCausalLM.from_pretrained("gpt2")
-    # optimizer = torch.optim.Adam(model.parameters())
+    lightning_model = GPT2LightningWrapper()
+
     wikitext_train = load_dataset("Salesforce/wikitext", "wikitext-2-v1", split="train")
     train_dataset = GPT2CausalLMDataset(text_dataset=wikitext_train)
-
-    loader = torch.utils.data.DataLoader(train_dataset, batch_size=8)
-
-    model_engine, optimizer, _, _ = deepspeed.initialize(
-        args=DSPArgs(deepspeed_config="dsp_config.json"),
-        model=model,
-        model_parameters=model.parameters(),
+    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8)
+
+    trainer = L.Trainer(
+        accelerator="gpu",
+        limit_train_batches=10,
+        max_epochs=1,
+        devices=2,
+        num_nodes=1,
+        strategy="ddp",
     )
 
-    model.train()
-    for batch_idx, batch in enumerate(loader):
-        if batch_idx == 10:
-            break
-        print(f"Step {batch_idx}")
-
-        device_batch = {k: v.to(model.device) for k, v in batch.items()}
-
-        model.zero_grad()
+    trainer.fit(model=lightning_model, train_dataloaders=train_loader)
 
-        loss = model_engine(**device_batch).loss
-        model_engine.backward(loss)
-
-        model_engine.step()
+    if int(os.environ["RANK"]) == 0:
+        return trainer.model.model
+    return None
 
 
 if __name__ == "__main__":
+    # hack to prevent lightning from recognizing SLURM environment...
+    os.environ["SLURM_JOB_NAME"] = "bash"
     Path("output").mkdir(exist_ok=True)
     results = torchrunx.launch(
         func=train,
         hostnames=["localhost"],
-        workers_per_host=1,
+        workers_per_host=2,
     )
+
+    trained_model: nn.Module = results.rank(0)
+    torch.save(trained_model.state_dict(), "output/model.pth")
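
Note: the example relies on a `GPT2CausalLMDataset` class defined earlier in `lightning.py`, outside this diff's context (only its `__getitem__` signature and the closing `}` of its returned dict are visible above). A rough, hypothetical sketch of what such a dataset might look like, assuming it tokenizes each wikitext row with the GPT-2 tokenizer and mirrors the input IDs as labels for the causal-LM loss:

```python
from torch.utils.data import Dataset
from transformers import AutoTokenizer


class GPT2CausalLMDataset(Dataset):
    """Hypothetical sketch, not part of this commit."""

    def __init__(self, text_dataset, max_length: int = 128):
        self.text_dataset = text_dataset
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.tokenizer.pad_token = self.tokenizer.eos_token  # GPT-2 ships without a pad token

    def __len__(self):
        return len(self.text_dataset)

    def __getitem__(self, idx):
        # Tokenize one wikitext row to a fixed length.
        encoding = self.tokenizer(
            self.text_dataset[idx]["text"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100  # ignore padding positions in the loss
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
```

Whatever the committed implementation actually does, `training_step` only assumes batches containing `input_ids`, `attention_mask`, and `labels` tensors that `AutoModelForCausalLM` accepts as keyword arguments.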
