Commit 1aa2326

Various Fabric documentation updates (Lightning-AI#17236)

1 parent 0489f2e · commit 1aa2326

5 files changed: +163 -38 lines changed

docs/source-fabric/api/fabric_args.rst

Lines changed: 1 addition & 1 deletion

```diff
@@ -78,7 +78,7 @@ Configure the devices to run on. Can be of type:
     # int: run on two GPUs
     fabric = Fabric(devices=2, accelerator="gpu")
 
-    # list: run on GPUs 1, 4 (by bus ordering)
+    # list: run on the 2nd (idx 1) and 5th (idx 4) GPUs (by bus ordering)
     fabric = Fabric(devices=[1, 4], accelerator="gpu")
     fabric = Fabric(devices="1, 4", accelerator="gpu")  # equivalent
 
```
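For quick reference, the `devices` spellings touched here can be read together. A minimal sketch of the documented semantics (the `-1`/`"auto"` equivalence comes from the gpu_basic.rst hunk further down; this is not part of the diff itself):

```python
import lightning as L

# int: any two available GPUs
fabric = L.Fabric(devices=2, accelerator="gpu")

# list or str: specific GPUs by 0-based index (bus ordering),
# i.e. the 2nd and 5th GPUs here
fabric = L.Fabric(devices=[1, 4], accelerator="gpu")
fabric = L.Fabric(devices="1, 4", accelerator="gpu")  # equivalent

# -1 (or "auto"): all available GPUs
fabric = L.Fabric(devices=-1, accelerator="gpu")
```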

docs/source-fabric/fundamentals/convert.rst

Lines changed: 1 addition & 3 deletions

```diff
@@ -54,9 +54,7 @@ All steps combined, this is how your code will change:
 .. code-block:: diff
 
       import torch
-      import torch.nn.functional as F
       from lightning.pytorch.demos import WikiText2, Transformer
-
     + import lightning as L
 
     - device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -79,7 +77,7 @@ All steps combined, this is how your code will change:
     - input, target = input.to(device), target.to(device)
       optimizer.zero_grad()
       output = model(input, target)
-      loss = F.nll_loss(output, target.view(-1))
+      loss = torch.nn.functional.nll_loss(output, target.view(-1))
     - loss.backward()
     + fabric.backward(loss)
       optimizer.step()
```
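With both hunks applied, the converted example reads roughly as below. This is a sketch: the construction lines between the hunks are not part of this diff and are assumed from the `WikiText2`/`Transformer` demo imports shown above.

```python
import torch
from lightning.pytorch.demos import WikiText2, Transformer

import lightning as L

fabric = L.Fabric()
fabric.launch()

# assumed setup, mirroring the demo imports above (not shown in the hunks)
dataset = WikiText2()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=20)
model = Transformer(vocab_size=dataset.vocab_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(dataloader)

model.train()
for epoch in range(2):
    for batch in dataloader:
        input, target = batch
        optimizer.zero_grad()
        output = model(input, target)
        loss = torch.nn.functional.nll_loss(output, target.view(-1))
        fabric.backward(loss)  # replaces loss.backward()
        optimizer.step()
```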

docs/source-fabric/index.rst

Lines changed: 1 addition & 3 deletions

```diff
@@ -17,9 +17,7 @@ Fabric is the fast and lightweight way to scale PyTorch models without boilerplate
 .. code-block:: diff
 
       import torch
-      import torch.nn.functional as F
       from lightning.pytorch.demos import WikiText2, Transformer
-
     + import lightning as L
 
     - device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -42,7 +40,7 @@ Fabric is the fast and lightweight way to scale PyTorch models without boilerplate
     - input, target = input.to(device), target.to(device)
       optimizer.zero_grad()
       output = model(input, target)
-      loss = F.nll_loss(output, target.view(-1))
+      loss = torch.nn.functional.nll_loss(output, target.view(-1))
     - loss.backward()
     + fabric.backward(loss)
       optimizer.step()
```

docs/source-pytorch/accelerators/gpu_basic.rst

Lines changed: 3 additions & 3 deletions

```diff
@@ -66,7 +66,7 @@ a comma separated list of GPU ids:
     Trainer(accelerator="gpu", devices="0, 1")
 
     # To use all available GPUs put -1 or '-1'
-    # equivalent to list(range(torch.cuda.device_count()))
+    # equivalent to `list(range(torch.cuda.device_count()))` and `"auto"`
     Trainer(accelerator="gpu", devices=-1)
 
 The table below lists examples of possible input formats and how they are interpreted by Lightning.
@@ -80,11 +80,11 @@ The table below lists examples of possible input formats and how they are interpreted by Lightning.
 +------------------+-----------+---------------------+---------------------------------+
 | [0]              | list      | [0]                 | GPU 0                           |
 +------------------+-----------+---------------------+---------------------------------+
-| [1, 3]           | list      | [1, 3]              | GPUs 1 and 3                    |
+| [1, 3]           | list      | [1, 3]              | GPU index 1 and 3 (0-based)     |
 +------------------+-----------+---------------------+---------------------------------+
 | "3"              | str       | [0, 1, 2]           | first 3 GPUs                    |
 +------------------+-----------+---------------------+---------------------------------+
-| "1, 3"           | str       | [1, 3]              | GPUs 1 and 3                    |
+| "1, 3"           | str       | [1, 3]              | GPU index 1 and 3 (0-based)     |
 +------------------+-----------+---------------------+---------------------------------+
 | "-1"             | str       | [0, 1, 2, ...]      | all available GPUs              |
 +------------------+-----------+---------------------+---------------------------------+
```
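The table's mapping can be spelled out in code. A hypothetical helper for illustration only (`interpret_devices` is not a real Trainer API; it just mirrors the rules above):

```python
import torch


def interpret_devices(devices):
    """Mirror the table: map a devices flag to a list of GPU indices."""
    if devices in (-1, "-1", "auto"):
        return list(range(torch.cuda.device_count()))  # all available GPUs
    if isinstance(devices, int):
        return list(range(devices))  # first N GPUs
    if isinstance(devices, str):
        if "," in devices:
            return [int(i) for i in devices.split(",")]  # specific indices
        return list(range(int(devices)))  # first N GPUs
    return list(devices)  # already a list of indices


assert interpret_devices(3) == [0, 1, 2]
assert interpret_devices([1, 3]) == [1, 3]
assert interpret_devices("1, 3") == [1, 3]
```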

src/lightning_fabric/README.md

Lines changed: 157 additions & 28 deletions

````diff
@@ -22,53 +22,182 @@ ______________________________________________________________________
 
 </div>
 
-## Maximum flexibility, minimum code changes
+# Lightning Fabric: Expert control.
 
-With just a few code changes, run any PyTorch model on any distributed hardware, no boilerplate!
+Run on any device at any scale with expert-level control over the PyTorch training loop and scaling strategy. You can even write your own Trainer.
 
-- Easily switch from running on CPU to GPU (Apple Silicon, CUDA, …), TPU, multi-GPU or even multi-node training
-- Use state-of-the-art distributed training strategies (DDP, FSDP, DeepSpeed) and mixed precision out of the box
-- All the device logic boilerplate is handled for you
-- Designed with multi-billion parameter models in mind
-- Build your own custom Trainer using Fabric primitives for training checkpointing, logging, and more
+Fabric is designed for the most complex models, of any size: foundation model scaling, LLMs, diffusion, transformers, reinforcement learning, active learning.
+
+<table>
+<tr>
+<th>What to change</th>
+<th>Resulting Fabric Code (copy me!)</th>
+</tr>
+<tr>
+<td>
+<sub>
 
 ```diff
 + import lightning as L
+  import torch; import torchvision as tv
 
-  import torch
-  import torch.nn as nn
-  from torch.utils.data import DataLoader, Dataset
-
-  class PyTorchModel(nn.Module):
-      ...
-
-  class PyTorchDataset(Dataset):
-      ...
-
-+ fabric = L.Fabric(accelerator="cuda", devices=8, strategy="ddp")
++ fabric = L.Fabric()
 + fabric.launch()
 
-- device = "cuda" if torch.cuda.is_available() else "cpu
-  model = PyTorchModel(...)
-  optimizer = torch.optim.SGD(model.parameters())
+  model = tv.models.resnet18()
+  optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+- device = "cuda" if torch.cuda.is_available() else "cpu"
+- model.to(device)
 + model, optimizer = fabric.setup(model, optimizer)
-  dataloader = DataLoader(PyTorchDataset(...), ...)
+
+  dataset = tv.datasets.CIFAR10("data", download=True,
+                                train=True,
+                                transform=tv.transforms.ToTensor())
+  dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
 + dataloader = fabric.setup_dataloaders(dataloader)
-  model.train()
 
+  model.train()
+  num_epochs = 10
   for epoch in range(num_epochs):
       for batch in dataloader:
-          input, target = batch
--         input, target = input.to(device), target.to(device)
+          inputs, labels = batch
+-         inputs, labels = inputs.to(device), labels.to(device)
           optimizer.zero_grad()
-          output = model(input)
-          loss = loss_fn(output, target)
+          outputs = model(inputs)
+          loss = torch.nn.functional.cross_entropy(outputs, labels)
 -         loss.backward()
 +         fabric.backward(loss)
           optimizer.step()
-      lr_scheduler.step()
 ```
 
+</sub>
+</td>
+<td>
+<sub>
+
+```python
+import lightning as L
+import torch; import torchvision as tv
+
+fabric = L.Fabric()
+fabric.launch()
+
+model = tv.models.resnet18()
+optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+model, optimizer = fabric.setup(model, optimizer)
+
+dataset = tv.datasets.CIFAR10("data", download=True,
+                              train=True,
+                              transform=tv.transforms.ToTensor())
+dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
+dataloader = fabric.setup_dataloaders(dataloader)
+
+model.train()
+num_epochs = 10
+for epoch in range(num_epochs):
+    for batch in dataloader:
+        inputs, labels = batch
+        optimizer.zero_grad()
+        outputs = model(inputs)
+        loss = torch.nn.functional.cross_entropy(outputs, labels)
+        fabric.backward(loss)
+        optimizer.step()
+```
+
+</sub>
+</td>
+</tr>
+</table>
+
+## Key features
+
+<details>
+  <summary>Easily switch from running on CPU to GPU (Apple Silicon, CUDA, …), TPU, multi-GPU or even multi-node training</summary>
+
+```python
+# Use your available hardware
+# no code changes needed
+fabric = Fabric()
+
+# Run on GPUs (CUDA or MPS)
+fabric = Fabric(accelerator="gpu")
+
+# 8 GPUs
+fabric = Fabric(accelerator="gpu", devices=8)
+
+# 256 GPUs, multi-node
+fabric = Fabric(accelerator="gpu", devices=8, num_nodes=32)
+
+# Run on TPUs
+fabric = Fabric(accelerator="tpu")
+```
+
+</details>
+
+<details>
+  <summary>Use state-of-the-art distributed training strategies (DDP, FSDP, DeepSpeed) and mixed precision out of the box</summary>
+
+```python
+# Use state-of-the-art distributed training techniques
+fabric = Fabric(strategy="ddp")
+fabric = Fabric(strategy="deepspeed")
+fabric = Fabric(strategy="fsdp")
+
+# Switch the precision
+fabric = Fabric(precision="16-mixed")
+fabric = Fabric(precision="64")
+```
+
+</details>
+
+<details>
+  <summary>All the device logic boilerplate is handled for you</summary>
+
+```diff
+# no more of this!
+- model.to(device)
+- batch.to(device)
+```
+
+</details>
+
+<details>
+  <summary>Build your own custom Trainer using Fabric primitives for training, checkpointing, logging, and more</summary>
+
+```python
+import torch
+import lightning as L
+
+
+class MyCustomTrainer:
+    def __init__(self, accelerator="auto", strategy="auto", devices="auto", precision="32-true"):
+        self.fabric = L.Fabric(accelerator=accelerator, strategy=strategy, devices=devices, precision=precision)
+
+    def fit(self, model, optimizer, dataloader, max_epochs):
+        self.fabric.launch()
+
+        model, optimizer = self.fabric.setup(model, optimizer)
+        dataloader = self.fabric.setup_dataloaders(dataloader)
+        model.train()
+
+        for epoch in range(max_epochs):
+            for batch in dataloader:
+                input, target = batch
+                optimizer.zero_grad()
+                output = model(input)
+                loss = torch.nn.functional.cross_entropy(output, target)
+                self.fabric.backward(loss)
+                optimizer.step()
+```
+
+You can find a more extensive example in our [examples](../../examples/fabric/build_your_own_trainer).
+
+</details>
+
+______________________________________________________________________
+
+<div align="center">
+  <a href="https://lightning.ai/docs/fabric/stable/">Read the Lightning Fabric docs</a>
+</div>
+
 ______________________________________________________________________
 
 # Getting started
````
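A usage sketch for the `MyCustomTrainer` pattern added above; the model and data choices are illustrative assumptions reusing the CIFAR10 setup from the comparison table, not part of the commit:

```python
import torch
import torchvision as tv

# assumes the MyCustomTrainer class from the README snippet above is in scope
model = tv.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
dataset = tv.datasets.CIFAR10("data", download=True, train=True,
                              transform=tv.transforms.ToTensor())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)

trainer = MyCustomTrainer(accelerator="auto", devices="auto")
trainer.fit(model, optimizer, dataloader, max_epochs=10)
```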
