@@ -1,10 +1,10 @@
-from dataclasses import dataclass
+import os
 from pathlib import Path
 
-import deepspeed
+import lightning as L
 import torch
-
 from datasets import load_dataset
+
 from torch import nn
 from torch.utils.data import Dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -42,47 +42,54 @@ def __getitem__(self, idx):
         }
 
 
-@dataclass
-class DSPArgs:
-    deepspeed_config: str
-    # train_batch_size: int
-    # batch_size: int
+class GPT2LightningWrapper(L.LightningModule):
+    def __init__(self):
+        super().__init__()
+        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
+
+    def training_step(self, batch, batch_idx):
+        device_batch = {k: v.to(self.model.device) for k, v in batch.items()}
+        loss = self.model(**device_batch).loss
+        self.log("train_loss", loss)
+        return loss
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
+        return optimizer
 
 
 def train():
-    model = AutoModelForCausalLM.from_pretrained("gpt2")
-    # optimizer = torch.optim.Adam(model.parameters())
+    lightning_model = GPT2LightningWrapper()
+
     wikitext_train = load_dataset("Salesforce/wikitext", "wikitext-2-v1", split="train")
     train_dataset = GPT2CausalLMDataset(text_dataset=wikitext_train)
-
-    loader = torch.utils.data.DataLoader(train_dataset, batch_size=8)
-
-    model_engine, optimizer, _, _ = deepspeed.initialize(
-        args=DSPArgs(deepspeed_config="dsp_config.json"),
-        model=model,
-        model_parameters=model.parameters(),
+    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8)
+
+    trainer = L.Trainer(
+        accelerator="gpu",
+        limit_train_batches=10,
+        max_epochs=1,
+        devices=2,
+        num_nodes=1,
+        strategy="ddp",
     )
 
-    model.train()
-    for batch_idx, batch in enumerate(loader):
-        if batch_idx == 10:
-            break
-        print(f"Step {batch_idx}")
-
-        device_batch = {k: v.to(model.device) for k, v in batch.items()}
-
-        model.zero_grad()
+    trainer.fit(model=lightning_model, train_dataloaders=train_loader)
 
-        loss = model_engine(**device_batch).loss
-        model_engine.backward(loss)
-
-        model_engine.step()
+    if int(os.environ["RANK"]) == 0:
+        return trainer.model.model
+    return None
 
 
 if __name__ == "__main__":
+    # hack to prevent lightning from recognizing SLURM environment...
+    os.environ["SLURM_JOB_NAME"] = "bash"
     Path("output").mkdir(exist_ok=True)
     results = torchrunx.launch(
         func=train,
         hostnames=["localhost"],
-        workers_per_host=1,
+        workers_per_host=2,
    )
+
+    trained_model: nn.Module = results.rank(0)
+    torch.save(trained_model.state_dict(), "output/model.pth")
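Not part of the commit: a minimal sketch of loading the saved weights back for inference, assuming the "output/model.pth" path written above and the same "gpt2" base model.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Rebuild the base GPT-2 model and load the weights saved by the training script.
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.load_state_dict(torch.load("output/model.pth"))
model.eval()

# Generate a short continuation to sanity-check the loaded weights.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("Wikipedia is a free online", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))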