Commit e6bf6a2 (1 parent: d953f44). Showing 3 changed files with 244 additions and 0 deletions.
The first of the three changed files is empty.
@@ -0,0 +1,111 @@
import argparse
import deepspeed
import torch
from torchvision.transforms import ToTensor
from torchvision.datasets import CIFAR10
from zero_offload.vit_pytorch import ViT
from time import perf_counter


def add_argument():
    """
    https://www.deepspeed.ai/tutorials/cifar-10/
    """
    parser = argparse.ArgumentParser(description='CIFAR')

    # cuda
    parser.add_argument('--with_cuda', default=False, action='store_true',
                        help='use CPU in case there\'s no GPU support')
    parser.add_argument('--use_ema', default=False, action='store_true',
                        help='whether to use exponential moving average')

    # train
    parser.add_argument('-b', '--batch_size', default=512, type=int,
                        help='mini-batch size (default: 512)')
    parser.add_argument('-e', '--epochs', default=30, type=int,
                        help='number of total epochs (default: 30)')
    parser.add_argument('--local_rank', type=int, default=-1,
                        help='local rank passed from distributed launcher')

    # Include DeepSpeed configuration arguments
    parser = deepspeed.add_config_arguments(parser)

    return parser.parse_args()

def main():
    args = add_argument()
    dataset = CIFAR10('.', download=True, transform=ToTensor())
    trainloader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=8)
    huge_model = ViT(
        image_size=32,
        patch_size=4,
        num_classes=10,
        dim=512,
        depth=8,
        heads=8,
        mlp_dim=2048,
        dropout=0.1,
        emb_dropout=0.1
    )
    lr = 0.001
    warmup_steps = 1000
    remain_steps = args.epochs * len(trainloader) - warmup_steps
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        huge_model.parameters(),
        lr=lr,
        betas=(0.8, 0.999),
        eps=1e-8,
        weight_decay=3e-7)
    # LambdaLR multiplies the base lr by the returned factor, so the lambda
    # returns a dimensionless factor: linear warmup, then linear decay.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda step: (step + 1) / warmup_steps if step < warmup_steps
        else max(0.0, (remain_steps - (step - warmup_steps)) / remain_steps))
    model_engine, _, trainloader_ds, _ = deepspeed.initialize(
        args=args,
        model=huge_model,
        model_parameters=huge_model.parameters(),
        training_data=dataset)

    # training with DeepSpeed (one timed pass over the data)
    start_time = perf_counter()
    for data in trainloader_ds:
        inputs = data[0].to(model_engine.device)
        labels = data[1].to(model_engine.device)

        outputs = model_engine(inputs)
        loss = criterion(outputs, labels)

        model_engine.backward(loss)
        model_engine.step()
    ds_time = (perf_counter() - start_time) / 60
    print('###################################################################')
    print(f'Training CIFAR10 using DeepSpeed took {ds_time:.3f} minutes')

    # regular training (the same pass, without DeepSpeed)
    model = huge_model.to(device)
    start_time = perf_counter()
    for data in trainloader:
        inputs = data[0].to(device)
        labels = data[1].to(device)

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()  # advance the warmup/decay schedule once per batch
    no_ds_time = (perf_counter() - start_time) / 60
    print('###################################################################')
    print(f'Training CIFAR10 without DeepSpeed took {no_ds_time:.3f} minutes')
    print('###################################################################')
    print(f'DeepSpeed accelerated training by {no_ds_time - ds_time:.3f} minutes')


if __name__ == '__main__':
    main()
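The script expects the DeepSpeed launcher to supply a JSON config via --deepspeed_config; no config file is included in this commit. As a rough sketch (the file name ds_config.json, the fp16 setting, and the ZeRO settings are all assumptions, chosen only to match the script's optimizer and the zero_offload package name), a ZeRO stage 2 config with CPU optimizer offload could be written out like this:

# Sketch only: every value below is an assumption, not part of the commit.
import json

ds_config = {
    "train_batch_size": 512,
    "optimizer": {
        "type": "Adam",
        "params": {"lr": 0.001, "betas": [0.8, 0.999],
                   "eps": 1e-8, "weight_decay": 3e-7},
    },
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"},  # ZeRO-Offload to CPU
    },
}

with open("ds_config.json", "w") as f:
    json.dump(ds_config, f, indent=2)

With that file written, a typical launch would look like deepspeed <script>.py --deepspeed_config ds_config.json, where the script name is a placeholder.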
@@ -0,0 +1,133 @@
""" | ||
https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_pytorch.py | ||
""" | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from einops import rearrange, repeat | ||
from torch import nn | ||
|
||
MIN_NUM_PATCHES = 16 | ||
|
||
class Residual(nn.Module): | ||
def __init__(self, fn): | ||
super().__init__() | ||
self.fn = fn | ||
def forward(self, x, **kwargs): | ||
return self.fn(x, **kwargs) + x | ||
|
||
class PreNorm(nn.Module): | ||
def __init__(self, dim, fn): | ||
super().__init__() | ||
self.norm = nn.LayerNorm(dim) | ||
self.fn = fn | ||
def forward(self, x, **kwargs): | ||
return self.fn(self.norm(x), **kwargs) | ||
|
||
class FeedForward(nn.Module): | ||
def __init__(self, dim, hidden_dim, dropout = 0.): | ||
super().__init__() | ||
self.net = nn.Sequential( | ||
nn.Linear(dim, hidden_dim), | ||
nn.GELU(), | ||
nn.Dropout(dropout), | ||
nn.Linear(hidden_dim, dim), | ||
nn.Dropout(dropout) | ||
) | ||
def forward(self, x): | ||
return self.net(x) | ||
|
||
class Attention(nn.Module): | ||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): | ||
super().__init__() | ||
inner_dim = dim_head * heads | ||
self.heads = heads | ||
self.scale = dim_head ** -0.5 | ||
|
||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) | ||
self.to_out = nn.Sequential( | ||
nn.Linear(inner_dim, dim), | ||
nn.Dropout(dropout) | ||
) | ||
|
||
def forward(self, x, mask = None): | ||
b, n, _, h = *x.shape, self.heads | ||
qkv = self.to_qkv(x).chunk(3, dim = -1) | ||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) | ||
|
||
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale | ||
mask_value = -torch.finfo(dots.dtype).max | ||
|
||
if mask is not None: | ||
mask = F.pad(mask.flatten(1), (1, 0), value = True) | ||
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' | ||
mask = mask[:, None, :] * mask[:, :, None] | ||
dots.masked_fill_(~mask, mask_value) | ||
del mask | ||
|
||
attn = dots.softmax(dim=-1) | ||
|
||
out = torch.einsum('bhij,bhjd->bhid', attn, v) | ||
out = rearrange(out, 'b h n d -> b n (h d)') | ||
out = self.to_out(out) | ||
return out | ||
|
||
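As a quick sanity check on the attention shapes (the dims mirror the ViT instantiated in the training script; this snippet is illustrative and not part of the commit):

# Shape check: einsum attention preserves the (batch, tokens, dim) layout.
import torch
from zero_offload.vit_pytorch import Attention

attn = Attention(dim=512, heads=8, dim_head=64)
x = torch.randn(2, 65, 512)   # 1 cls token + 64 patches of a 32x32 image
out = attn(x)                 # per-head tensors are (2, 8, 65, 64) internally
assert out.shape == (2, 65, 512)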
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
            ]))
    def forward(self, x, mask = None):
        for attn, ff in self.layers:
            x = attn(x, mask = mask)
            x = ff(x)
        return x
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2
        assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.patch_size = patch_size

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img, mask = None):
        p = self.patch_size

        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x, mask)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
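Finally, a forward-pass smoke test at CIFAR10 resolution, using the same hyperparameters the training script passes to ViT (illustrative only, not part of the commit):

# Smoke test: CIFAR10-shaped inputs in, one logit per class out.
import torch
from zero_offload.vit_pytorch import ViT

model = ViT(image_size=32, patch_size=4, num_classes=10, dim=512,
            depth=8, heads=8, mlp_dim=2048, dropout=0.1, emb_dropout=0.1)
imgs = torch.randn(4, 3, 32, 32)   # a small batch of 32x32 RGB images
logits = model(imgs)
assert logits.shape == (4, 10)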