Skip to content

Commit e6bf6a2

Browse files
committed
added testing w/ deepspeed
1 parent d953f44 commit e6bf6a2

File tree

3 files changed

+244
-0
lines changed

3 files changed

+244
-0
lines changed

zero_offload/__init__.py

Whitespace-only changes.

zero_offload/single_gpu_train.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import argparse
2+
import deepspeed
3+
import torch
4+
from torchvision.transforms import ToTensor
5+
from torchvision.datasets import CIFAR10
6+
from zero_offload.vit_pytorch import ViT
7+
from time import perf_counter
8+
9+
10+
def add_argument():
11+
"""
12+
https://www.deepspeed.ai/tutorials/cifar-10/
13+
"""
14+
parser=argparse.ArgumentParser(description='CIFAR')
15+
16+
# data
17+
# cuda
18+
parser.add_argument('--with_cuda', default=False, action='store_true',
19+
help='use CPU in case there\'s no GPU support')
20+
parser.add_argument('--use_ema', default=False, action='store_true',
21+
help='whether use exponential moving average')
22+
23+
# train
24+
parser.add_argument('-b', '--batch_size', default=512, type=int,
25+
help='mini-batch size (default: 32)')
26+
parser.add_argument('-e', '--epochs', default=30, type=int,
27+
help='number of total epochs (default: 30)')
28+
parser.add_argument('--local_rank', type=int, default=-1,
29+
help='local rank passed from distributed launcher')
30+
31+
# Include DeepSpeed configuration arguments
32+
parser = deepspeed.add_config_arguments(parser)
33+
34+
return parser.parse_args()
35+
36+
37+
def main():
38+
args = add_argument()
39+
dataset = CIFAR10('.', download=True, transform=ToTensor())
40+
trainloader = torch.utils.data.DataLoader(dataset,
41+
batch_size=args.batch_size,
42+
shuffle=True,
43+
num_workers=8)
44+
huge_model = ViT(
45+
image_size=32,
46+
patch_size=4,
47+
num_classes=10,
48+
dim=512,
49+
depth=8,
50+
heads=8,
51+
mlp_dim=2048,
52+
dropout=0.1,
53+
emb_dropout=0.1
54+
)
55+
lr = 0.001
56+
warmup_steps = 1000
57+
remain_steps = (args.epochs * len(trainloader) - warmup_steps)
58+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
59+
criterion = torch.nn.CrossEntropyLoss()
60+
optimizer = torch.optim.Adam(
61+
huge_model.parameters(),
62+
lr=lr,
63+
betas=(0.8, 0.999),
64+
eps=1e-8,
65+
weight_decay=3e-7)
66+
torch.optim.lr_scheduler.LambdaLR(
67+
optimizer,
68+
lambda epoch: (epoch + 1) / warmup_steps * lr if epoch < warmup_steps else (epoch - warmup_steps) * lr / remain_steps)
69+
model_engine, _, trainloader_ds, _ = deepspeed.initialize(
70+
args=args,
71+
model=huge_model,
72+
model_parameters=huge_model.parameters(),
73+
training_data=dataset)
74+
75+
# training w/ DeepSpeed
76+
start_time = perf_counter()
77+
for data in trainloader_ds:
78+
inputs = data[0].to(model_engine.device)
79+
labels = data[1].to(model_engine.device)
80+
81+
outputs = model_engine(inputs)
82+
loss = criterion(outputs, labels)
83+
84+
model_engine.backward(loss)
85+
model_engine.step()
86+
ds_time = (perf_counter() - start_time) / 60
87+
print('###################################################################')
88+
print(f'Training CIFAR10 using DeepSpeed used {ds_time:.3f} minutes')
89+
90+
# regular training
91+
model = huge_model.to(device)
92+
start_time = perf_counter()
93+
for data in trainloader:
94+
inputs = data[0].to(device)
95+
labels = data[1].to(device)
96+
97+
outputs = model(inputs)
98+
loss = criterion(outputs, labels)
99+
100+
optimizer.zero_grad()
101+
loss.backward()
102+
optimizer.step()
103+
no_ds_time = (perf_counter() - start_time) / 60
104+
print('###################################################################')
105+
print(f'Training CIFAR10 without using DeepSpeed used {no_ds_time:.3f} minutes')
106+
print('###################################################################')
107+
print(f'DeepSpeed accelerated training by {no_ds_time - ds_time:.3f} minutes')
108+
109+
110+
if __name__ == '__main__':
111+
main()

zero_offload/vit_pytorch.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""
2+
https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_pytorch.py
3+
"""
4+
5+
import torch
6+
import torch.nn.functional as F
7+
from einops import rearrange, repeat
8+
from torch import nn
9+
10+
MIN_NUM_PATCHES = 16
11+
12+
class Residual(nn.Module):
13+
def __init__(self, fn):
14+
super().__init__()
15+
self.fn = fn
16+
def forward(self, x, **kwargs):
17+
return self.fn(x, **kwargs) + x
18+
19+
class PreNorm(nn.Module):
20+
def __init__(self, dim, fn):
21+
super().__init__()
22+
self.norm = nn.LayerNorm(dim)
23+
self.fn = fn
24+
def forward(self, x, **kwargs):
25+
return self.fn(self.norm(x), **kwargs)
26+
27+
class FeedForward(nn.Module):
28+
def __init__(self, dim, hidden_dim, dropout = 0.):
29+
super().__init__()
30+
self.net = nn.Sequential(
31+
nn.Linear(dim, hidden_dim),
32+
nn.GELU(),
33+
nn.Dropout(dropout),
34+
nn.Linear(hidden_dim, dim),
35+
nn.Dropout(dropout)
36+
)
37+
def forward(self, x):
38+
return self.net(x)
39+
40+
class Attention(nn.Module):
41+
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
42+
super().__init__()
43+
inner_dim = dim_head * heads
44+
self.heads = heads
45+
self.scale = dim_head ** -0.5
46+
47+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
48+
self.to_out = nn.Sequential(
49+
nn.Linear(inner_dim, dim),
50+
nn.Dropout(dropout)
51+
)
52+
53+
def forward(self, x, mask = None):
54+
b, n, _, h = *x.shape, self.heads
55+
qkv = self.to_qkv(x).chunk(3, dim = -1)
56+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
57+
58+
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
59+
mask_value = -torch.finfo(dots.dtype).max
60+
61+
if mask is not None:
62+
mask = F.pad(mask.flatten(1), (1, 0), value = True)
63+
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
64+
mask = mask[:, None, :] * mask[:, :, None]
65+
dots.masked_fill_(~mask, mask_value)
66+
del mask
67+
68+
attn = dots.softmax(dim=-1)
69+
70+
out = torch.einsum('bhij,bhjd->bhid', attn, v)
71+
out = rearrange(out, 'b h n d -> b n (h d)')
72+
out = self.to_out(out)
73+
return out
74+
75+
class Transformer(nn.Module):
76+
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
77+
super().__init__()
78+
self.layers = nn.ModuleList([])
79+
for _ in range(depth):
80+
self.layers.append(nn.ModuleList([
81+
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
82+
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
83+
]))
84+
def forward(self, x, mask = None):
85+
for attn, ff in self.layers:
86+
x = attn(x, mask = mask)
87+
x = ff(x)
88+
return x
89+
90+
class ViT(nn.Module):
91+
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
92+
super().__init__()
93+
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
94+
num_patches = (image_size // patch_size) ** 2
95+
patch_dim = channels * patch_size ** 2
96+
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
97+
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
98+
99+
self.patch_size = patch_size
100+
101+
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
102+
self.patch_to_embedding = nn.Linear(patch_dim, dim)
103+
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
104+
self.dropout = nn.Dropout(emb_dropout)
105+
106+
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
107+
108+
self.pool = pool
109+
self.to_latent = nn.Identity()
110+
111+
self.mlp_head = nn.Sequential(
112+
nn.LayerNorm(dim),
113+
nn.Linear(dim, num_classes)
114+
)
115+
116+
def forward(self, img, mask = None):
117+
p = self.patch_size
118+
119+
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
120+
x = self.patch_to_embedding(x)
121+
b, n, _ = x.shape
122+
123+
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
124+
x = torch.cat((cls_tokens, x), dim=1)
125+
x += self.pos_embedding[:, :(n + 1)]
126+
x = self.dropout(x)
127+
128+
x = self.transformer(x, mask)
129+
130+
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
131+
132+
x = self.to_latent(x)
133+
return self.mlp_head(x)

0 commit comments

Comments
 (0)