import torch
import pytorch_lightning as pl
from pytorchvideo.models import create_res_basic_head
from model import Classifier
from data import make_ucf11_datamodule
# Fine-tuning example: pretrained slow_r50 with a fresh classification head.

# Build the datamodule (per the original note, this downloads the data and
# prepares the splits).
dm = make_ucf11_datamodule()

# Pull the pretrained slow_r50 network from Torch Hub.
model = torch.hub.load(
    "facebookresearch/pytorchvideo",
    "slow_r50",
    pretrained=True,
)

# Freeze every block except the last one, so only the replacement head
# below keeps trainable parameters.
backbone = model.blocks[:-1]
backbone.requires_grad_(False)

# Swap the final block for a new head with one output per dataset label.
# in_features=2048 is assumed to match the backbone's output width
# (TODO confirm against the slow_r50 definition).
head = create_res_basic_head(
    in_features=2048,
    out_features=dm.num_labels,
)
model.blocks[-1] = head

# Wrap the network in the project's Lightning module and train it.
classifier = Classifier(model, lr=2e-4)

# NOTE(review): `gpus=` was deprecated in Lightning 1.7 and removed in 2.0;
# newer releases expect `accelerator="gpu", devices=1` instead. Confirm
# against the Lightning version this project pins before changing it.
trainer = pl.Trainer(
    gpus=1,
    precision=16,
    max_epochs=4,
)
trainer.fit(classifier, dm)
import torch
import pytorch_lightning as pl
from pytorchvideo.models import create_res_basic_head
from model import Classifier
from data import make_ucf11_datamodule
# Generic template: the same training pipeline with a user-supplied model.

# Build the datamodule (per the original note, this downloads the data and
# prepares the splits).
dm = make_ucf11_datamodule()

# Placeholder only: `...` is the Ellipsis object, not a network. Replace it
# with any torch model that accepts video tensors and outputs class
# predictions before running this script.
model = ...

# Wrap the network in the project's Lightning module and train it.
classifier = Classifier(model, lr=2e-4)

# NOTE(review): `gpus=` was deprecated in Lightning 1.7 and removed in 2.0;
# newer releases expect `accelerator="gpu", devices=1` instead. Confirm
# against the Lightning version this project pins before changing it.
trainer = pl.Trainer(
    gpus=1,
    precision=16,
    max_epochs=4,
)
trainer.fit(classifier, dm)