@@ -22,53 +22,182 @@ ______________________________________________________________________

</div>

- ## Maximum flexibility, minimum code changes
+ # Lightning Fabric: Expert control.

- With just a few code changes, run any PyTorch model on any distributed hardware, no boilerplate!
+ Run on any device at any scale with expert-level control over the PyTorch training loop and scaling strategy. You can even write your own Trainer.

- - Easily switch from running on CPU to GPU (Apple Silicon, CUDA, …), TPU, multi-GPU or even multi-node training
- - Use state-of-the-art distributed training strategies (DDP, FSDP, DeepSpeed) and mixed precision out of the box
- - All the device logic boilerplate is handled for you
- - Designed with multi-billion parameter models in mind
- - Build your own custom Trainer using Fabric primitives for training checkpointing, logging, and more
+ Fabric is designed for the most complex models and training workloads: foundation model scaling, LLMs, diffusion models, transformers, reinforcement learning, and active learning, at any scale.
+
+ <table>
+ <tr>
+ <th>What to change</th>
+ <th>Resulting Fabric Code (copy me!)</th>
+ </tr>
+ <tr>
+ <td>
+ <sub>

```diff
+ import lightning as L
+ import torch; import torchvision as tv

- import torch
- import torch.nn as nn
- from torch.utils.data import DataLoader, Dataset
-
- class PyTorchModel(nn.Module):
-     ...
-
- class PyTorchDataset(Dataset):
-     ...
-
- + fabric = L.Fabric(accelerator="cuda", devices=8, strategy="ddp")
+ + fabric = L.Fabric()
+ fabric.launch()

- - device = "cuda" if torch.cuda.is_available() else "cpu
- model = PyTorchModel(...)
- optimizer = torch.optim.SGD(model.parameters())
+ model = tv.models.resnet18()
+ optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+ - device = "cuda" if torch.cuda.is_available() else "cpu"
+ - model.to(device)
+ model, optimizer = fabric.setup(model, optimizer)
- dataloader = DataLoader(PyTorchDataset(...), ...)
+
+ dataset = tv.datasets.CIFAR10("data", download=True,
+                               train=True,
+                               transform=tv.transforms.ToTensor())
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
+ dataloader = fabric.setup_dataloaders(dataloader)
- model.train()

+ model.train()
+ num_epochs = 10
for epoch in range(num_epochs):
    for batch in dataloader:
-         input, target = batch
- -         input, target = input.to(device), target.to(device)
+         inputs, labels = batch
+ -         inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
-         output = model(input)
-         loss = loss_fn(output, target)
+         outputs = model(inputs)
+         loss = torch.nn.functional.cross_entropy(outputs, labels)
-         loss.backward()
+         fabric.backward(loss)
        optimizer.step()
-         lr_scheduler.step()
```

+ </sub>
+ </td>
+ <td>
+ <sub>
+
+ ```python
+ import lightning as L
+ import torch; import torchvision as tv
+
+ fabric = L.Fabric()
+ fabric.launch()
+
+ model = tv.models.resnet18()
+ optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+ model, optimizer = fabric.setup(model, optimizer)
+
+ dataset = tv.datasets.CIFAR10("data", download=True,
+                               train=True,
+                               transform=tv.transforms.ToTensor())
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
+ dataloader = fabric.setup_dataloaders(dataloader)
+
+ model.train()
+ num_epochs = 10
+ for epoch in range(num_epochs):
+     for batch in dataloader:
+         inputs, labels = batch
+         optimizer.zero_grad()
+         outputs = model(inputs)
+         loss = torch.nn.functional.cross_entropy(outputs, labels)
+         fabric.backward(loss)
+         optimizer.step()
+ ```
+
+ </sub>
+ </td>
+ </tr>
+ </table>
+
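+ Note how `fabric.backward(loss)` replaces the usual `loss.backward()`: Fabric routes the backward pass through the active strategy and precision plugin (for example, gradient scaling under `precision="16-mixed"`), so it must be called in place of backward on the loss directly.
+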
+ ## Key features
+
+ <details>
+ <summary>Easily switch from running on CPU to GPU (Apple Silicon, CUDA, …), TPU, multi-GPU or even multi-node training</summary>
+
+ ```python
+ # Use your available hardware
+ # no code changes needed
+ fabric = Fabric()
+
+ # Run on GPUs (CUDA or MPS)
+ fabric = Fabric(accelerator="gpu")
+
+ # 8 GPUs
+ fabric = Fabric(accelerator="gpu", devices=8)
+
+ # 256 GPUs, multi-node
+ fabric = Fabric(accelerator="gpu", devices=8, num_nodes=32)
+
+ # Run on TPUs
+ fabric = Fabric(accelerator="tpu")
+ ```
+
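+ `fabric.launch()` can also take a function to run in each spawned process, which is handy when the training code is wrapped in a function. A minimal sketch (the `train` function name and the print statement are illustrative):
+
+ ```python
+ import lightning as L
+
+ def train(fabric):
+     # runs once per process; Fabric passes itself as the first argument
+     print(f"Hello from process {fabric.global_rank} on device {fabric.device}")
+
+ fabric = L.Fabric(accelerator="cpu", devices=2)
+ fabric.launch(train)
+ ```
+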
+ </details>
+
+ <details>
+ <summary>Use state-of-the-art distributed training strategies (DDP, FSDP, DeepSpeed) and mixed precision out of the box</summary>
+
+ ```python
+ # Use state-of-the-art distributed training techniques
+ fabric = Fabric(strategy="ddp")
+ fabric = Fabric(strategy="deepspeed")
+ fabric = Fabric(strategy="fsdp")
+
+ # Switch the precision
+ fabric = Fabric(precision="16-mixed")
+ fabric = Fabric(precision="64")
+ ```
+
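+ Beyond the string shorthands, a strategy can be passed as a configured object for finer control. A sketch, assuming the `DDPStrategy` and `FSDPStrategy` constructor arguments shown here match your installed version (check the API docs if not):
+
+ ```python
+ from lightning.fabric.strategies import DDPStrategy, FSDPStrategy
+
+ # forward additional arguments to torch.nn.parallel.DistributedDataParallel
+ fabric = Fabric(strategy=DDPStrategy(find_unused_parameters=True))
+
+ # offload sharded parameters and gradients to CPU with FSDP
+ fabric = Fabric(strategy=FSDPStrategy(cpu_offload=True))
+ ```
+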
+ </details>
+
+ <details>
+ <summary>All the device logic boilerplate is handled for you</summary>
+
+ ```diff
+ # no more of this!
+ - model.to(device)
+ - batch.to(device)
+ ```
+
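+ `fabric.setup(...)` places the model on the right device and `fabric.setup_dataloaders(...)` does the same for every batch, so the manual calls above disappear. For any remaining tensor, a one-line sketch of the escape hatch:
+
+ ```python
+ # move a tensor (or a collection of tensors) to the device Fabric selected
+ batch = fabric.to_device(batch)
+ ```
+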
+ </details>
+
+ <details>
+ <summary>Build your own custom Trainer using Fabric primitives for training, checkpointing, logging, and more</summary>
+
+ ```python
+ import lightning as L
+ import torch
+
+
+ class MyCustomTrainer:
+     def __init__(self, accelerator="auto", strategy="auto", devices="auto", precision="32-true"):
+         self.fabric = L.Fabric(accelerator=accelerator, strategy=strategy, devices=devices, precision=precision)
+
+     def fit(self, model, optimizer, dataloader, max_epochs):
+         self.fabric.launch()
+
+         model, optimizer = self.fabric.setup(model, optimizer)
+         dataloader = self.fabric.setup_dataloaders(dataloader)
+         model.train()
+
+         for epoch in range(max_epochs):
+             for batch in dataloader:
+                 input, target = batch
+                 optimizer.zero_grad()
+                 output = model(input)
+                 loss = torch.nn.functional.cross_entropy(output, target)
+                 self.fabric.backward(loss)
+                 optimizer.step()
+ ```
+
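+ The summary also promises checkpointing and logging; a minimal sketch of how these primitives could slot into `fit` (the `"checkpoint.ckpt"` path and the logged keys are illustrative):
+
+ ```python
+ # gather everything that should be saved; Fabric handles strategy-specific saving
+ state = {"model": model, "optimizer": optimizer, "epoch": epoch}
+ self.fabric.save("checkpoint.ckpt", state)
+
+ # later: restore the objects in `state` in place
+ self.fabric.load("checkpoint.ckpt", state)
+
+ # send metrics to any logger passed via L.Fabric(loggers=...)
+ self.fabric.log_dict({"loss": loss.item(), "epoch": epoch})
+ ```
+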
+ You can find a more extensive example in our [examples](../../examples/fabric/build_your_own_trainer).
+
+ </details>
+
+ ______________________________________________________________________
+
+ <div align="center">
+ <a href="https://lightning.ai/docs/fabric/stable/">Read the Lightning Fabric docs</a>
+ </div>
+
______________________________________________________________________

# Getting started