Skip to content

Commit

Permalink
added testing w/ deepspeed
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrewzh112 committed Jan 27, 2021
1 parent d953f44 commit e6bf6a2
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 0 deletions.
Empty file added zero_offload/__init__.py
Empty file.
111 changes: 111 additions & 0 deletions zero_offload/single_gpu_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import argparse
import deepspeed
import torch
from torchvision.transforms import ToTensor
from torchvision.datasets import CIFAR10
from zero_offload.vit_pytorch import ViT
from time import perf_counter


def add_argument():
"""
https://www.deepspeed.ai/tutorials/cifar-10/
"""
parser=argparse.ArgumentParser(description='CIFAR')

# data
# cuda
parser.add_argument('--with_cuda', default=False, action='store_true',
help='use CPU in case there\'s no GPU support')
parser.add_argument('--use_ema', default=False, action='store_true',
help='whether use exponential moving average')

# train
parser.add_argument('-b', '--batch_size', default=512, type=int,
help='mini-batch size (default: 32)')
parser.add_argument('-e', '--epochs', default=30, type=int,
help='number of total epochs (default: 30)')
parser.add_argument('--local_rank', type=int, default=-1,
help='local rank passed from distributed launcher')

# Include DeepSpeed configuration arguments
parser = deepspeed.add_config_arguments(parser)

return parser.parse_args()


def main():
args = add_argument()
dataset = CIFAR10('.', download=True, transform=ToTensor())
trainloader = torch.utils.data.DataLoader(dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=8)
huge_model = ViT(
image_size=32,
patch_size=4,
num_classes=10,
dim=512,
depth=8,
heads=8,
mlp_dim=2048,
dropout=0.1,
emb_dropout=0.1
)
lr = 0.001
warmup_steps = 1000
remain_steps = (args.epochs * len(trainloader) - warmup_steps)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
huge_model.parameters(),
lr=lr,
betas=(0.8, 0.999),
eps=1e-8,
weight_decay=3e-7)
torch.optim.lr_scheduler.LambdaLR(
optimizer,
lambda epoch: (epoch + 1) / warmup_steps * lr if epoch < warmup_steps else (epoch - warmup_steps) * lr / remain_steps)
model_engine, _, trainloader_ds, _ = deepspeed.initialize(
args=args,
model=huge_model,
model_parameters=huge_model.parameters(),
training_data=dataset)

# training w/ DeepSpeed
start_time = perf_counter()
for data in trainloader_ds:
inputs = data[0].to(model_engine.device)
labels = data[1].to(model_engine.device)

outputs = model_engine(inputs)
loss = criterion(outputs, labels)

model_engine.backward(loss)
model_engine.step()
ds_time = (perf_counter() - start_time) / 60
print('###################################################################')
print(f'Training CIFAR10 using DeepSpeed used {ds_time:.3f} minutes')

# regular training
model = huge_model.to(device)
start_time = perf_counter()
for data in trainloader:
inputs = data[0].to(device)
labels = data[1].to(device)

outputs = model(inputs)
loss = criterion(outputs, labels)

optimizer.zero_grad()
loss.backward()
optimizer.step()
no_ds_time = (perf_counter() - start_time) / 60
print('###################################################################')
print(f'Training CIFAR10 without using DeepSpeed used {no_ds_time:.3f} minutes')
print('###################################################################')
print(f'DeepSpeed accelerated training by {no_ds_time - ds_time:.3f} minutes')


if __name__ == '__main__':
main()
133 changes: 133 additions & 0 deletions zero_offload/vit_pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""
https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_pytorch.py
"""

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn

MIN_NUM_PATCHES = 16

class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)

class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5

self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)

def forward(self, x, mask = None):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
mask_value = -torch.finfo(dots.dtype).max

if mask is not None:
mask = F.pad(mask.flatten(1), (1, 0), value = True)
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
mask = mask[:, None, :] * mask[:, :, None]
dots.masked_fill_(~mask, mask_value)
del mask

attn = dots.softmax(dim=-1)

out = torch.einsum('bhij,bhjd->bhid', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return out

class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
]))
def forward(self, x, mask = None):
for attn, ff in self.layers:
x = attn(x, mask = mask)
x = ff(x)
return x

class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

self.patch_size = patch_size

self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)

self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

self.pool = pool
self.to_latent = nn.Identity()

self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)

def forward(self, img, mask = None):
p = self.patch_size

x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape

cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)

x = self.transformer(x, mask)

x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

x = self.to_latent(x)
return self.mlp_head(x)

0 comments on commit e6bf6a2

Please sign in to comment.