-
Notifications
You must be signed in to change notification settings - Fork 421
Finetuning Example #41
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
nateraw
wants to merge
15
commits into
facebookresearch:main
Choose a base branch
from
nateraw:nate/finetune
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
559b3b6
:sparkles: introduce finetuning example
nateraw 4c97dfc
:art: improve structure, modularity of code
nateraw 42b3d7c
:lipstick: style
nateraw fb55f1f
:rotating_light: remove unused import
nateraw 7fd0880
:construction: wip
nateraw 0ec28aa
:construction: wip
nateraw e126c63
:pencil: Writing docs.
nateraw e4c8cbf
:art: improve structure + cleanup unnecessary code
nateraw 7df5f3e
:lipstick: apply style
nateraw 9dba4ad
:art: move sampler statement to its own line
nateraw 0886737
:pencil: writing docs
nateraw 8a9f527
:pencil: update docstring with more specific path
nateraw 608c16b
:pencil: add periods to keep it consistent
nateraw cde0ce7
:fire: remove inline comments
nateraw 6608f9a
:fire: removing incomplete finetuning tutorial for now
nateraw File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,344 @@ | ||
import itertools
from pathlib import Path
from random import shuffle
from shutil import unpack_archive
from typing import Optional, Tuple

import pytorch_lightning as pl
import requests
import torch
from pytorchvideo.data import LabeledVideoDataset, make_clip_sampler
from pytorchvideo.data.labeled_video_dataset import labeled_video_dataset
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    ShortSideScale,
    UniformTemporalSubsample,
)
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
)
|
||
|
||
class LabeledVideoDataModule(pl.LightningDataModule):
    """Base LightningDataModule for labeled-video datasets.

    Expects data laid out as ``<root>/<split>/<class>/<video files>`` (see
    ``__init__`` for the full tree). Subclasses describe their dataset source
    via the class attributes below; ``prepare_data`` downloads and unpacks the
    archive when a ``SOURCE_URL`` is provided.
    """

    # URL of a zip archive containing the dataset; None disables downloading.
    SOURCE_URL: Optional[str] = None
    # Directory name created under `root` once the archive is unpacked.
    SOURCE_DIR_NAME: str = ""
    # Number of target classes in the dataset.
    NUM_CLASSES: int = 700
    # Whether to verify the SSL certificate when downloading SOURCE_URL.
    VERIFY_SSL: bool = True

    def __init__(
        self,
        root: str = "./",
        clip_duration: int = 2,
        video_num_subsampled: int = 8,
        video_crop_size: int = 224,
        video_means: Tuple[float, float, float] = (0.45, 0.45, 0.45),
        video_stds: Tuple[float, float, float] = (0.225, 0.225, 0.225),
        video_min_short_side_scale: int = 256,
        video_max_short_side_scale: int = 320,
        video_horizontal_flip_p: float = 0.5,
        batch_size: int = 4,
        workers: int = 4,
        **kwargs
    ):
        """
        A LabeledVideoDataModule expects a dataset in the following format:

        /root                                        # Root Folder
        ├── train                                    # Split Folder
        │   ├── archery                              # Class Folder
        │   │   ├── -1q7jA3DXQM_000005_000015.mp4    # Videos
        │   │   ├── -5NN5hdIwTc_000036_000046.mp4
        │   │   ...
        │   ├── bowling
        │   │   ├── -5ExwuF5IUI_000030_000040.mp4
        │   │   ...
        │   ├── high_jump
        │   │   ├── -5ExwuF5IUI_000030_000040.mp4
        │   │   ...
        ├── val
        │   ├── archery
        │   │   ├── -1q7jA3DXQM_000005_000015.mp4
        │   │   ├── -5NN5hdIwTc_000036_000046.mp4
        │   │   ...
        │   ├── bowling
        │   │   ├── -5ExwuF5IUI_000030_000040.mp4
        │   │   ...

        Args:
            root (str, optional): Directory where your dataset is stored. Defaults to "./".
            clip_duration (int, optional): Duration of clip samples. Defaults to 2.
            video_num_subsampled (int, optional): Number of subsamples to take of individual videos. Defaults to 8.
            video_crop_size (int, optional): Size to crop the video to. Defaults to 224.
            video_means (Tuple[float, float, float], optional): Means used to normalize dataset. Defaults to (0.45, 0.45, 0.45).
            video_stds (Tuple[float, float, float], optional): Standard deviations used to normalize dataset. Defaults to (0.225, 0.225, 0.225).
            video_min_short_side_scale (int, optional): min_size arg passed to pytorchvideo.transforms.RandomShortSideScale. Defaults to 256.
            video_max_short_side_scale (int, optional): max_size arg passed to pytorchvideo.transforms.RandomShortSideScale. Defaults to 320.
            video_horizontal_flip_p (float, optional): Probability of flipping a training example horizontally. Defaults to 0.5.
            batch_size (int, optional): Number of examples per batch. Defaults to 4.
            workers (int, optional): Number of DataLoader workers. Defaults to 4.
        """
        super().__init__()
        self.root = root
        self.data_path = Path(self.root) / self.SOURCE_DIR_NAME
        self.clip_duration = clip_duration
        self.video_num_subsampled = video_num_subsampled
        self.video_crop_size = video_crop_size
        self.video_means = video_means
        self.video_stds = video_stds
        self.video_min_short_side_scale = video_min_short_side_scale
        self.video_max_short_side_scale = video_max_short_side_scale
        self.video_horizontal_flip_p = video_horizontal_flip_p
        self.batch_size = batch_size
        self.workers = workers

        # Transforms applied to train dataset (random scale/crop/flip for
        # augmentation, applied to the "video" key of each sample dict).
        self.train_transform = ApplyTransformToKey(
            key="video",
            transform=Compose(
                [
                    UniformTemporalSubsample(self.video_num_subsampled),
                    Lambda(lambda x: x / 255.0),  # scale pixel values to [0, 1]
                    Normalize(self.video_means, self.video_stds),
                    RandomShortSideScale(
                        min_size=self.video_min_short_side_scale,
                        max_size=self.video_max_short_side_scale,
                    ),
                    RandomCrop(self.video_crop_size),
                    RandomHorizontalFlip(p=self.video_horizontal_flip_p),
                ]
            ),
        )

        # Transforms applied on val dataset or for inference (deterministic
        # scale + center crop, no augmentation).
        self.val_transform = ApplyTransformToKey(
            key="video",
            transform=Compose(
                [
                    UniformTemporalSubsample(self.video_num_subsampled),
                    Lambda(lambda x: x / 255.0),  # scale pixel values to [0, 1]
                    Normalize(self.video_means, self.video_stds),
                    ShortSideScale(self.video_min_short_side_scale),
                    CenterCrop(self.video_crop_size),
                ]
            ),
        )

    def prepare_data(self):
        """Download the dataset if it doesn't already exist. This runs only on rank 0."""
        if not (self.SOURCE_URL is None or self.SOURCE_DIR_NAME is None):
            if not self.data_path.exists():
                download_and_unzip(self.SOURCE_URL, self.root, verify=self.VERIFY_SSL)

    def train_dataloader(self):
        """Build the train DataLoader, recreating the dataset each call."""
        # NOTE(review): `trainer.use_ddp` was removed in newer PyTorch
        # Lightning releases — confirm against the pinned PL version.
        do_use_ddp = self.trainer is not None and self.trainer.use_ddp
        self.train_dataset = LimitDataset(
            labeled_video_dataset(
                data_path=str(Path(self.data_path) / "train"),
                clip_sampler=make_clip_sampler("random", self.clip_duration),
                transform=self.train_transform,
                decode_audio=False,
                video_sampler=DistributedSampler if do_use_ddp else RandomSampler,
            )
        )
        return DataLoader(
            self.train_dataset, batch_size=self.batch_size, num_workers=self.workers
        )

    def val_dataloader(self):
        """Build the val DataLoader, recreating the dataset each call."""
        do_use_ddp = self.trainer is not None and self.trainer.use_ddp
        self.val_dataset = LimitDataset(
            labeled_video_dataset(
                data_path=str(Path(self.data_path) / "val"),
                clip_sampler=make_clip_sampler("uniform", self.clip_duration),
                transform=self.val_transform,
                decode_audio=False,
                video_sampler=DistributedSampler if do_use_ddp else RandomSampler,
            )
        )
        return DataLoader(
            self.val_dataset, batch_size=self.batch_size, num_workers=self.workers
        )
|
||
|
||
class UCF11DataModule(LabeledVideoDataModule):

    SOURCE_URL: str = "https://www.crcv.ucf.edu/data/YouTube_DataSet_Annotated.zip"
    SOURCE_DIR_NAME: str = "action_youtube_naudio"
    NUM_CLASSES: int = 11
    # The UCF server's certificate fails verification, so it is disabled here.
    VERIFY_SSL: bool = False

    def __init__(self, **kwargs):
        """
        The UCF11 Dataset contains 11 action classes: basketball shooting, biking/cycling, diving,
        golf swinging, horse back riding, soccer juggling, swinging, tennis swinging, trampoline jumping,
        volleyball spiking, and walking with a dog.

        For each class, the videos are grouped into 25 group/scene folders containing at least 4 video clips each.
        The video clips in the same scene folder share some common features, such as the same actor, similar
        background, similar viewpoint, and so on.

        The folder structure looks like the following:

        /root/action_youtube_naudio
        ├── basketball                        # Class Folder Path
        │   ├── v_shooting_01                 # Scene/Group Folder Path
        │   │   ├── v_shooting_01_01.avi      # Video Path
        │   │   ├── v_shooting_01_02.avi
        │   │   ├── v_shooting_01_03.avi
        │   │   ├── ...
        │   ├── v_shooting_02
        │   ├── v_shooting_03
        │   ├── ...
        │   ...
        ├── biking
        │   ├── v_biking_01
        │   │   ├── v_biking_01_01.avi
        │   │   ├── v_biking_01_02.avi
        │   │   ├── v_biking_01_03.avi
        │   ├── v_biking_02
        │   ├── v_biking_03
        │   ...
        ...

        We take 80% of all scenes and use the videos within for training. The remaining scenes' videos
        are used for validation. We do this so the validation data contains only videos from scenes/actors
        that the model has not seen yet.
        """
        super().__init__(**kwargs)

    def setup(self, stage: Optional[str] = None):
        """Set up anything needed for initializing train/val datasets. This runs on all nodes."""

        # Lightning may call setup() once per stage (fit/validate/test) on the
        # same instance. The scene split below is randomized, so rebuilding it
        # would reassign scenes and leak training scenes into validation —
        # build the split exactly once.
        if hasattr(self, "train_paths"):
            return

        # Names of classes to predict.
        # Ex. ['basketball', 'biking', 'diving', ...]
        self.classes = sorted(x.name for x in self.data_path.glob("*") if x.is_dir())

        # Mapping from label to class id.
        # Ex. {'basketball': 0, 'biking': 1, 'diving': 2, ...}
        self.label_to_id = {}

        # A list to hold all available scenes across all classes.
        scene_folders = []

        for class_id, class_name in enumerate(self.classes):

            self.label_to_id[class_name] = class_id

            # The path of a class folder within self.data_path.
            # Ex. 'action_youtube_naudio/{basketball|biking|diving|...}'
            class_folder = self.data_path / class_name

            # Collect scene folders within this class.
            # Ex. 'action_youtube_naudio/basketball/v_shooting_01'
            for scene_folder in filter(Path.is_dir, class_folder.glob("v_*")):
                scene_folders.append(scene_folder)

        # Randomly shuffle the scene folders before splitting them into train/val.
        # Determinism relies on the global random seed set by the caller
        # (e.g. pl.seed_everything), so all ranks produce the same split.
        shuffle(scene_folders)

        # Determine number of scenes in train/validation splits.
        self.num_train_scenes = int(0.8 * len(scene_folders))
        self.num_val_scenes = len(scene_folders) - self.num_train_scenes

        # Collect train/val paths to videos within each scene folder.
        # Validation only uses videos from scenes not seen by model during training.
        self.train_paths = []
        self.val_paths = []
        for i, scene_path in enumerate(scene_folders):

            # The actual name of the class (Ex. 'basketball').
            class_name = scene_path.parent.name

            # Loop over all the videos within the given scene folder.
            for video_path in scene_path.glob("*.avi"):

                # Construct a tuple containing (<path to a video>, <dict containing extra attributes/metadata>).
                # In our case, we assign the class's ID as 'label'.
                labeled_path = (video_path, {"label": self.label_to_id[class_name]})

                if i < self.num_train_scenes:
                    self.train_paths.append(labeled_path)
                else:
                    self.val_paths.append(labeled_path)

    def train_dataloader(self):
        """Build the train DataLoader from the explicit (path, label) list."""
        self.train_dataset = LimitDataset(
            LabeledVideoDataset(
                self.train_paths,
                clip_sampler=make_clip_sampler("random", self.clip_duration),
                decode_audio=False,
                transform=self.train_transform,
                video_sampler=RandomSampler,
            )
        )
        return DataLoader(
            self.train_dataset, batch_size=self.batch_size, num_workers=self.workers
        )

    def val_dataloader(self):
        """Build the val DataLoader from the explicit (path, label) list."""
        # NOTE(review): unlike the parent class, this always uses RandomSampler
        # even under DDP — confirm that is intentional for multi-GPU runs.
        self.val_dataset = LimitDataset(
            LabeledVideoDataset(
                self.val_paths,
                clip_sampler=make_clip_sampler("uniform", self.clip_duration),
                decode_audio=False,
                transform=self.val_transform,
                video_sampler=RandomSampler,
            )
        )
        return DataLoader(
            self.val_dataset, batch_size=self.batch_size, num_workers=self.workers
        )
|
||
|
||
def download_and_unzip(url, data_dir="./", verify=True):
    """Download a zip file from a given URL and unpack it within data_dir.

    If the zip file already exists locally, the download is skipped and the
    existing archive is unpacked again.

    Args:
        url (str): A URL to a zip file.
        data_dir (str, optional): Directory where the zip will be unpacked. Defaults to "./".
        verify (bool, optional): Whether to verify SSL certificate when requesting the zip file. Defaults to True.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    data_dir = Path(data_dir)
    zipfile_name = url.split("/")[-1]
    data_zip_path = data_dir / zipfile_name
    data_dir.mkdir(exist_ok=True, parents=True)

    if not data_zip_path.exists():
        resp = requests.get(url, verify=verify)
        # Fail loudly on HTTP errors; otherwise an error page would be written
        # to disk as the "zip", and (since existing files are never
        # re-downloaded) the corrupt archive would poison every later run.
        resp.raise_for_status()

        with data_zip_path.open("wb") as f:
            f.write(resp.content)

    unpack_archive(data_zip_path, extract_dir=data_dir)
|
||
|
||
class LimitDataset(torch.utils.data.Dataset):
    """Wrap a video dataset so it presents a fixed number of samples.

    Several of the underlying videos may be corrupted while fetching or
    decoding (and thus skipped), but we always want the same number of steps
    per epoch. ``__len__`` therefore reports ``dataset.num_videos`` while
    ``__getitem__`` ignores the index and simply pulls the next clip from a
    shared iterator.
    """

    def __init__(self, dataset):
        """
        Args:
            dataset: An iterable of clip samples exposing a ``num_videos``
                attribute (e.g. a pytorchvideo LabeledVideoDataset).
        """
        super().__init__()
        self.dataset = dataset
        # Chain two *fresh* iterations of the dataset so that, when skipped
        # corrupt videos make one pass shorter than num_videos, a second pass
        # can supply the remaining samples. The previous code repeated the
        # same iterator object twice, so the second repetition was always
        # empty and a short first pass raised StopIteration mid-epoch.
        # (For datasets whose __iter__ returns self, behavior is unchanged.)
        self.dataset_iter = itertools.chain.from_iterable(
            iter(dataset) for _ in range(2)
        )

    def __getitem__(self, index):
        # The index is intentionally ignored; samples come from the iterator.
        return next(self.dataset_iter)

    def __len__(self):
        # Fixed epoch length, regardless of how many clips actually decode.
        return self.dataset.num_videos
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import pytorch_lightning as pl | ||
from data import UCF11DataModule | ||
from models import SlowResnet50LightningModel | ||
from train import parse_args | ||
|
||
|
||
def train(args):
    """Seed everything, build the UCF11 datamodule and model, and run a fit."""
    pl.seed_everything(224)
    arg_dict = vars(args)
    datamodule = UCF11DataModule(**arg_dict)
    model = SlowResnet50LightningModel(num_classes=datamodule.NUM_CLASSES, **arg_dict)
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model, datamodule)
|
||
|
||
def main():
    """Parse CLI args, then launch training either locally or on a SLURM cluster."""
    args = parse_args()

    if not args.on_cluster:
        # Local run: train directly in this process.
        train(args)
        return

    # Cluster run: import lazily so local runs don't require the slurm module.
    from slurm import copy_and_run_with_config

    copy_and_run_with_config(
        train,
        args,
        args.working_directory,
        job_name=args.job_name,
        time="72:00:00",
        partition=args.partition,
        gpus_per_node=args.gpus,
        ntasks_per_node=args.gpus,
        cpus_per_task=10,
        mem="470GB",
        nodes=args.num_nodes,
        constraint="volta32gb",
    )
|
||
|
||
# Script entry point: only run training when executed directly.
if __name__ == "__main__":
    main()
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.