Added CLIP module and redesigned tokenizer apis #81

Merged Aug 29, 2022 · 82 commits
Changes from 1 commit
83fceae
saved workd
Anhforth Jun 29, 2022
f113f7c
saved workd
Anhforth Jun 29, 2022
708ce15
saved work on 6.29
Anhforth Jun 30, 2022
0d1b079
transformed tokenizer: progressing
Anhforth Jul 1, 2022
2763b5d
Opt 30b (#16)
920232796 Jul 1, 2022
3e52907
fix bert tokenizer issue (#18)
Anhforth Jul 1, 2022
deb2612
reconstruct the tokenizer structure
ZhaodongYan1 Jul 3, 2022
c2c6e9d
tested the new tokenizer
Anhforth Jul 4, 2022
fc2b5d8
removed some redundant codes and added sp model
Anhforth Jul 4, 2022
7da1757
updated the tokenizer
ZhaodongYan1 Jul 4, 2022
7c8c0b1
saved work
Anhforth Jul 5, 2022
3a0c8cb
Opt 66b (#19)
920232796 Jul 6, 2022
265d35a
saved work on 7.6
Anhforth Jul 6, 2022
4f8d715
updated release version
Anhforth Jul 6, 2022
efc1310
fix tokenizer issue
Anhforth Jul 6, 2022
59531e7
temp save
Anhforth Jul 6, 2022
3b6c16a
tokenizer test passed
Anhforth Jul 6, 2022
a7ff8f3
fixed some errors
Anhforth Jul 7, 2022
f4ff1a8
test of tokenizer transform
Anhforth Jul 7, 2022
811d9e9
fixed conflicts
Anhforth Jul 7, 2022
1406d89
fixed error
Anhforth Jul 7, 2022
b30eefa
add encode_plus
Anhforth Jul 8, 2022
9b81869
fix bug multi_gpu_training
920232796 Jul 8, 2022
7ad38a0
Merge pull request #21 from baai-open-internal/fix_multi_gpu_training
Anhforth Jul 8, 2022
72ffd6a
changed the version
Anhforth Jul 8, 2022
e6f89a6
fix_validation_bug (#24)
920232796 Jul 11, 2022
29ea850
updated the version
Anhforth Jul 11, 2022
4c68936
updated
Anhforth Jul 15, 2022
4834f23
modified encoder_plus
Anhforth Jul 15, 2022
8d44329
add vit and examples
920232796 Jul 15, 2022
81c438d
vit and examples
920232796 Jul 15, 2022
da24628
Update base_model.py
marscrazy Jul 15, 2022
aff728b
Update vit.py
marscrazy Jul 15, 2022
e5a0ddb
modify readme.md
920232796 Jul 15, 2022
fe56b8b
modify readme.md
920232796 Jul 15, 2022
fc6c32e
delete annotating code
920232796 Jul 15, 2022
cd45e5c
Vit xzh (#25)
920232796 Jul 15, 2022
5448084
updated
Anhforth Jul 17, 2022
eb555fc
updated
Anhforth Jul 17, 2022
9649aa4
performing tests on examples
Anhforth Jul 17, 2022
67c1288
finished example testing
Anhforth Jul 18, 2022
faee281
Merge branch 'develop' into vit_xzh
BAAI-OpenPlatform Jul 19, 2022
06f0b69
Merge pull request #28 from baai-open-internal/vit_xzh
BAAI-OpenPlatform Jul 19, 2022
deaa120
Merge pull request #27 from baai-open-internal/develop
marscrazy Jul 20, 2022
9558a47
env trainer
920232796 Jul 20, 2022
c35d4b6
Merge pull request #29 from baai-open-internal/env_args
marscrazy Jul 20, 2022
437caa4
vit-checkpoint-activations
920232796 Jul 21, 2022
dc6fc3d
vit-checkpoint-activations
920232796 Jul 21, 2022
c1cec9f
Merge pull request #33 from baai-open-internal/vit-checkpointing-acti…
marscrazy Jul 21, 2022
d74cf92
update
jongjyh Jul 25, 2022
044bc80
Merge pull request #34 from baai-open-internal/fix_eval_loss
marscrazy Jul 25, 2022
d85f8af
merged the master
Anhforth Jul 26, 2022
1b5ecc6
inference and train
wchh-2000 Jul 29, 2022
1fe6d3e
fix bug bert model
xuanricheng Aug 5, 2022
0c243d6
add autoloader and example training data
wchh-2000 Aug 15, 2022
2c28a7d
updated seq2seq
shunxing1234 Aug 16, 2022
e03247e
update
wchh-2000 Aug 16, 2022
4a4b003
Merge pull request #52 from baai-open-internal/add_clip
marscrazy Aug 17, 2022
ce5fd31
Merge branch 'master' into transform_tokenizer
Anhforth Aug 18, 2022
8353cd3
Update train.py
marscrazy Aug 18, 2022
5d5e135
Delete tst_superglue.py
marscrazy Aug 18, 2022
4c6ba56
updated according to comments
BAAI-OpenPlatform Aug 19, 2022
6076287
Merge pull request #50 from baai-open-internal/bert_model
BAAI-OpenPlatform Aug 19, 2022
c11e232
merged the clip tokenizer
BAAI-OpenPlatform Aug 22, 2022
6e135ef
merged clip tokenizer
BAAI-OpenPlatform Aug 23, 2022
fd06e4d
Update inference_clip.py
marscrazy Aug 25, 2022
b61b708
Update auto_loader.py
marscrazy Aug 25, 2022
25b659b
Update glm_10b_en_tokenizer.py
marscrazy Aug 25, 2022
8cffa38
Merge pull request #20 from baai-open-internal/transform_tokenizer
marscrazy Aug 25, 2022
9117f78
swinv1v2
920232796 Aug 25, 2022
f3186d9
Merge pull request #58 from baai-open-internal/swinv1v2_checkpoint_ac…
marscrazy Aug 25, 2022
4bd211d
updated the version
Anhforth Aug 25, 2022
6ef4190
updated the requirement packages list
Anhforth Aug 25, 2022
036e337
fixed some issues
BAAI-OpenPlatform Aug 26, 2022
edfd518
fixed some issues
BAAI-OpenPlatform Aug 26, 2022
497d709
tried to fix the data directory not found error
BAAI-OpenPlatform Aug 26, 2022
1ac43c0
fixed issues in running glm_seq2seq
BAAI-OpenPlatform Aug 26, 2022
351fba7
Update test_glm_seq2seq.py
marscrazy Aug 26, 2022
35b5d9a
Merge pull request #59 from baai-open-internal/fix_issues
marscrazy Aug 26, 2022
d71ee8d
merged upstream
Anhforth Aug 26, 2022
e3836aa
Update setup.py
marscrazy Aug 26, 2022
619398b
Merge branch 'develop' of github.com:FlagAI-Open/FlagAI into develop
Anhforth Aug 26, 2022
vit-checkpoint-activations
Signed-off-by: zhaohu xing <920232796@qq.com>
920232796 committed Jul 21, 2022
commit 437caa4e7816b461e3167f2a6981ad3f17bf40d8
10 changes: 10 additions & 0 deletions config.json
@@ -0,0 +1,10 @@
{
    "img_size": 224,
    "patch_size": 16,
    "in_chans": 3,
    "embed_dim": 768,
    "depth": 12,
    "num_heads": 12,
    "num_classes": 100,
    "checkpoint_activations": false
}
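These are the standard ViT-Base/16 hyperparameters (224×224 inputs split into 16×16 patches, 12 transformer layers, 12 heads, 768-dim embeddings), with num_classes set to 100 for CIFAR-100 and activation checkpointing off by default. A minimal sketch of how such a file maps onto the VitConfig class added in flagai/model/vision/vit.py below, assuming the constructor accepts these keys as keyword arguments (num_classes is set separately, mirroring the model code in this PR):

import json

from flagai.model.vision.vit import VitConfig  # module added in this PR

with open("config.json") as f:
    cfg = json.load(f)

# assumption: VitConfig takes the remaining keys as keyword arguments,
# mirroring vit_config = VitConfig(**config) in vit.py below
num_classes = cfg.pop("num_classes")
vit_config = VitConfig(**cfg)
vit_config.num_classes = num_classes
print(vit_config.embed_dim, vit_config.checkpoint_activations)  # 768 False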
1 change: 1 addition & 0 deletions examples/vit_cifar100/train_DDP.py
@@ -28,6 +28,7 @@
     save_interval=1000,
     num_checkpoints=1,
     hostfile="./hostfile",
+    training_script="train_DDP.py"
 )

 def build_cifar():
1 change: 1 addition & 0 deletions examples/vit_cifar100/train_deepspeed.py
@@ -29,6 +29,7 @@
     save_interval=1000,
     num_checkpoints=1,
     hostfile="./hostfile",
+    training_script="train_deepspeed.py"
 )

 def build_cifar():
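Both launcher examples gain the same one-line addition: training_script hands the trainer the name of the entry file, presumably so the multi-process launcher can respawn the correct script on each worker. That reading is inferred from the argument name and the surrounding code rather than from documentation.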
70 changes: 70 additions & 0 deletions flagai/model/vision/helpers.py
@@ -0,0 +1,70 @@

import os
if os.getenv('ENV_TYPE') == 'deepspeed':
    from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
else:
    from torch.utils.checkpoint import checkpoint
import torch
from itertools import chain


def checkpoint_seq(
        functions,
        x,
        every=1,
        flatten=False,
        skip_last=False,
):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a sequence into segments
    and checkpoint each segment. All segments except the last run in
    :func:`torch.no_grad` manner, i.e., without storing intermediate
    activations. The inputs of each checkpointed segment are saved for
    re-running the segment in the backward pass.
    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.

    .. warning::
        Checkpointing currently only supports :func:`torch.autograd.backward`
        and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
        is not supported.

    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
        x: A Tensor that is input to :attr:`functions`
        every: checkpoint every-n functions (default: 1)
        flatten (bool): flatten nn.Sequential of nn.Sequentials
        skip_last (bool): skip checkpointing the last function in the sequence if True

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`x`

    Example:
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_seq(model, input_var, every=2)
    """
    def run_function(start, end, functions):
        def forward(_x):
            for j in range(start, end + 1):
                _x = functions[j](_x)
            return _x
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = functions.children()
    if flatten:
        functions = chain.from_iterable(functions)
    if not isinstance(functions, (tuple, list)):
        functions = tuple(functions)

    num_checkpointed = len(functions)
    if skip_last:
        num_checkpointed -= 1
    end = -1
    for start in range(0, num_checkpointed, every):
        end = min(start + every - 1, num_checkpointed - 1)
        x = checkpoint(run_function(start, end, functions), x)
    if skip_last:
        return run_function(end + 1, len(functions) - 1, functions)(x)
    return x
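To make the trade-off concrete: checkpointing skips storing intermediate activations during the forward pass and recomputes them during backward, trading compute for memory. A minimal sketch of checkpoint_seq on a toy stack of blocks; only the helper itself comes from the path this PR adds, everything else is plain PyTorch:

import torch
import torch.nn as nn

from flagai.model.vision.helpers import checkpoint_seq  # module added in this PR

# eight small blocks standing in for transformer blocks
blocks = nn.Sequential(*[nn.Sequential(nn.Linear(64, 64), nn.GELU()) for _ in range(8)])
x = torch.randn(4, 64, requires_grad=True)  # requires_grad so checkpointed segments get gradients

# checkpoint every 2 blocks: 4 segments, each re-run during backward
# instead of keeping its intermediate activations alive
y = checkpoint_seq(blocks, x, every=2)
y.sum().backward()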

53 changes: 27 additions & 26 deletions flagai/model/vision/vit.py
@@ -33,6 +33,7 @@
 from flagai.model.vision.layers.drop import DropPath
 from flagai.model.vision.layers.weight_init import trunc_normal_, lecun_normal_
 from flagai.model.base_model import BaseModel
+from flagai.model.vision.helpers import checkpoint_seq

 class VitConfig:
     def __init__(self,
@@ -53,7 +54,7 @@ def __init__(self,
                  attn_drop_rate=0.,
                  drop_path_rate=0.,
                  weight_init='',
-                 checkpoint_activations=None):
+                 checkpoint_activations=False):
         pass
         self.img_size=img_size
         self.patch_size=patch_size
@@ -74,7 +75,6 @@ def __init__(self,
         self.weight_init=weight_init
         self.checkpoint_activations = checkpoint_activations
-

 def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
     if not depth_first and include_root:
         fn(module=module, name=name)
@@ -206,42 +206,42 @@ def __init__(
         block_fn=Block
         vit_config = VitConfig(**config)
         vit_config.num_classes = num_classes
-        config = vit_config
+        # config = vit_config

-        assert config.global_pool in ('', 'avg', 'token')
-        assert config.class_token or config.global_pool != 'token'
-        use_fc_norm = config.global_pool == 'avg' if config.fc_norm is None else config.fc_norm
+        assert vit_config.global_pool in ('', 'avg', 'token')
+        assert vit_config.class_token or vit_config.global_pool != 'token'
+        use_fc_norm = vit_config.global_pool == 'avg' if vit_config.fc_norm is None else vit_config.fc_norm
         norm_layer = partial(nn.LayerNorm, eps=1e-6)
         act_layer = nn.GELU

         self.num_classes = num_classes
-        self.global_pool = config.global_pool
-        self.num_features = self.embed_dim = config.embed_dim  # num_features for consistency with other models
-        self.num_tokens = 1 if config.class_token else 0
-        self.grad_checkpointing = False
+        self.global_pool = vit_config.global_pool
+        self.num_features = self.embed_dim = vit_config.embed_dim  # num_features for consistency with other models
+        self.num_tokens = 1 if vit_config.class_token else 0
+        self.grad_checkpointing = vit_config.checkpoint_activations

         self.patch_embed = embed_layer(
-            img_size=config.img_size, patch_size=config.patch_size, in_chans=config.in_chans, embed_dim=config.embed_dim)
+            img_size=vit_config.img_size, patch_size=vit_config.patch_size, in_chans=vit_config.in_chans, embed_dim=vit_config.embed_dim)
         num_patches = self.patch_embed.num_patches

-        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if self.num_tokens > 0 else None
-        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, config.embed_dim) * .02)
-        self.pos_drop = nn.Dropout(p=config.drop_rate)
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, vit_config.embed_dim)) if self.num_tokens > 0 else None
+        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, vit_config.embed_dim) * .02)
+        self.pos_drop = nn.Dropout(p=vit_config.drop_rate)

-        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.depth)]  # stochastic depth decay rule
+        dpr = [x.item() for x in torch.linspace(0, vit_config.drop_path_rate, vit_config.depth)]  # stochastic depth decay rule
         self.blocks = nn.Sequential(*[
             block_fn(
-                dim=config.embed_dim, num_heads=config.num_heads, mlp_ratio=config.mlp_ratio, qkv_bias=config.qkv_bias, init_values=config.init_values,
-                drop=config.drop_rate, attn_drop=config.attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
-            for i in range(config.depth)])
-        self.norm = norm_layer(config.embed_dim) if not use_fc_norm else nn.Identity()
+                dim=vit_config.embed_dim, num_heads=vit_config.num_heads, mlp_ratio=vit_config.mlp_ratio, qkv_bias=vit_config.qkv_bias, init_values=vit_config.init_values,
+                drop=vit_config.drop_rate, attn_drop=vit_config.attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
+            for i in range(vit_config.depth)])
+        self.norm = norm_layer(vit_config.embed_dim) if not use_fc_norm else nn.Identity()

         # Classifier Head
-        self.fc_norm = norm_layer(config.embed_dim) if use_fc_norm else nn.Identity()
+        self.fc_norm = norm_layer(vit_config.embed_dim) if use_fc_norm else nn.Identity()
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

-        if config.weight_init != 'skip':
-            self.init_weights(config.weight_init)
+        if vit_config.weight_init != 'skip':
+            self.init_weights(vit_config.weight_init)

     def init_weights(self, mode=''):
         assert mode in ('jax', 'jax_nlhb', 'moco', '')
@@ -290,10 +290,11 @@ def forward_features(self, x):
         if self.cls_token is not None:
             x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
         x = self.pos_drop(x + self.pos_embed)
-        # if self.grad_checkpointing and not torch.jit.is_scripting():
-        #     x = checkpoint_seq(self.blocks, x)
-        # else:
-        x = self.blocks(x)
+        if self.config["checkpoint_activations"]:
+            x = checkpoint_seq(self.blocks, x)
+        else:
+            x = self.blocks(x)

         x = self.norm(x)
         return x

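The forward path now gates recomputation on the checkpoint_activations flag instead of the previously commented-out grad_checkpointing branch. A sketch of toggling it from the caller's side, assuming the model keeps the raw config dict on self.config as the forward_features code above implies:

from flagai.auto_model.auto_loader import AutoLoader

loader = AutoLoader(task_name="classification",
                    model_name="vit-base-p16-224",
                    num_classes=100)
model = loader.get_model()

# assumption: self.config is the plain dict checked in forward_features;
# flip the flag before the first forward pass to trade compute for memory
model.config["checkpoint_activations"] = True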
5 changes: 3 additions & 2 deletions flagai/trainer.py
@@ -348,7 +348,8 @@ def train(self,
               train_dataset=None,
               valid_dataset=None,
               metric_methods=[],
-              collate_fn=None):
+              collate_fn=None,
+              find_unused_parameters=True):
         """Training Loops"""
         """
         Trainer is a simple but unified training and eval loop for PyTorch/Deepspeed/Megatron-LM.
@@ -416,7 +417,7 @@ def train(self,
             model.to(torch.device('cuda', self.local_rank))
             model = DDP(model,
                         device_ids=[self.local_rank],
-                        find_unused_parameters=True)
+                        find_unused_parameters=find_unused_parameters)

         elif self.env_type == 'pytorch':
             model.to(self.pytorch_device)
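Exposing find_unused_parameters matters because DDP errors out when a parameter never receives a gradient while the flag is False, and conversely pays an extra graph traversal every step when it is left True unnecessarily. A minimal sketch of the new knob, assuming a Trainer and datasets set up as in the examples in this PR and a model whose every parameter participates in the loss:

trainer.train(model,
              train_dataset=train_dataset,
              valid_dataset=val_dataset,
              metric_methods=[["accuracy", validate]],
              collate_fn=collate_fn,
              find_unused_parameters=False)  # skip DDP's unused-parameter search when all params get grads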
90 changes: 90 additions & 0 deletions train_deepspeed.py
@@ -0,0 +1,90 @@
import torch
from torchvision import transforms
from torchvision.datasets import CIFAR100
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
from flagai.trainer import Trainer
from flagai.auto_model.auto_loader import AutoLoader

lr = 2e-5
n_epochs = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env_type = "pytorchDDP"
trainer = Trainer(
    env_type=env_type,
    experiment_name="vit-cifar100-deepspeed",
    batch_size=128,
    num_gpus=2,
    fp16=True,
    gradient_accumulation_steps=1,
    lr=lr,
    weight_decay=1e-5,
    epochs=n_epochs,
    log_interval=10,
    eval_interval=100,
    load_dir=None,
    pytorch_device=device,
    save_dir="checkpoints_vit_cifar100_deepspeed",
    save_interval=1000,
    num_checkpoints=1,
    hostfile="./hostfile",
    deepspeed_config="./deepspeed.json",
    training_script="train_deepspeed.py",
    checkpoint_activations=True,
)


def build_cifar():
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.Resize(224),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    train_dataset = CIFAR100(root="./data/cifar100", train=True, download=True, transform=transform_train)
    test_dataset = CIFAR100(root="./data/cifar100", train=False, download=True, transform=transform_test)
    return train_dataset, test_dataset


def collate_fn(batch):
    images = torch.stack([b[0] for b in batch])
    if trainer.fp16:
        images = images.half()
    labels = [b[1] for b in batch]
    labels = torch.tensor(labels).long()
    return {"images": images, "labels": labels}


def validate(logits, labels, meta=None):
    _, predicted = logits.max(1)
    total = labels.size(0)
    correct = predicted.eq(labels).sum().item()
    return correct / total


if __name__ == '__main__':
    loader = AutoLoader(task_name="classification",
                        model_name="vit-base-p16-224",
                        num_classes=100)

    model = loader.get_model()
    # optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs)
    train_dataset, val_dataset = build_cifar()

    trainer.train(model,
                  # optimizer=optimizer,
                  # lr_scheduler=scheduler,
                  train_dataset=train_dataset,
                  valid_dataset=val_dataset,
                  metric_methods=[["accuracy", validate]],
                  collate_fn=collate_fn)
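The script points at ./deepspeed.json, but that file is not part of this commit, and note that env_type is still "pytorchDDP" here, so the DeepSpeed config would only take effect with env_type="deepspeed". Purely as an illustration, a minimal DeepSpeed config consistent with the Trainer arguments above (batch size 128 per GPU, no gradient accumulation, fp16) might look like this; which keys FlagAI forwards to DeepSpeed is an assumption:

{
    "train_micro_batch_size_per_gpu": 128,
    "gradient_accumulation_steps": 1,
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 1
    }
}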
76 changes: 76 additions & 0 deletions validate.py
@@ -0,0 +1,76 @@
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
from flagai.auto_model.auto_loader import AutoLoader
import os
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def build_cifar():
    transform_test = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    test_dataset = CIFAR100(root="./cifar100", train=False, download=True, transform=transform_test)
    return test_dataset


def collate_fn(batch):
    images = torch.stack([b[0] for b in batch])
    labels = [b[1] for b in batch]
    labels = torch.tensor(labels).long()
    return {"images": images, "labels": labels}


def validate(logits, labels, meta=None):
    _, predicted = logits.max(1)
    total = labels.size(0)
    correct = predicted.eq(labels).sum().item()
    return correct / total


if __name__ == '__main__':
    model_save_dir = "./checkpoints_vit_cifar100"
    print(f"loading model from: {model_save_dir}")
    loader = AutoLoader(task_name="classification",
                        model_name="vit-base-p16-224",
                        num_classes=100)

    model = loader.get_model()

    model.load_state_dict(torch.load(os.path.join(model_save_dir, "38000", "pytorch_model.bin"), map_location=device)["module"])
    print("model loaded successfully")
    model.to(device)

    val_dataset = build_cifar()

    val_dataloader = DataLoader(val_dataset,
                                batch_size=1,
                                shuffle=False,
                                collate_fn=collate_fn)
    index = 0
    accuracy = 0.0
    for data in tqdm(val_dataloader, total=len(val_dataloader)):
        index += 1
        data = {k: v.to(device) for k, v in data.items()}
        labels = data["labels"]
        pred = model(**data)["logits"]
        acc = validate(pred, labels)
        accuracy += acc

    print(f"accuracy is {accuracy / index}")
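One detail worth flagging in the load call: the checkpoint is indexed with ["module"] because checkpoints saved through the distributed trainer wrap the model weights under a "module" key, while a plain single-process state dict has no such wrapper. A hedged fallback sketch, assuming the same paths and model as the script above:

import os
import torch
from flagai.auto_model.auto_loader import AutoLoader

loader = AutoLoader(task_name="classification",
                    model_name="vit-base-p16-224",
                    num_classes=100)
model = loader.get_model()

ckpt = torch.load(os.path.join("./checkpoints_vit_cifar100", "38000", "pytorch_model.bin"),
                  map_location="cpu")
# trainer-saved checkpoints nest the weights under "module"; fall back to the raw dict otherwise
state_dict = ckpt.get("module", ckpt)
model.load_state_dict(state_dict)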