Commit: Publish train scripts
PeterL1n committed Mar 7, 2021
1 parent 999fa88 commit d993eaa
Showing 5 changed files with 646 additions and 15 deletions.
6 changes: 4 additions & 2 deletions README.md
@@ -28,6 +28,7 @@ Official repository for the paper [Real-Time High-Resolution Background Matting]
 
 ## Updates
 
+* [Mar 06 2021] Training script is published.
 * [Feb 28 2021] Paper is accepted to CVPR 2021.
 * [Jan 09 2021] PhotoMatte85 dataset is now published.
 * [Dec 21 2020] We updated our project to MIT License, which permits commercial use.
@@ -48,8 +49,9 @@ Official repository for the paper [Real-Time High-Resolution Background Matting]
 
 ### Datasets
 
-* VideoMatte240K (Coming soon)
 * [PhotoMatte85](https://drive.google.com/file/d/1KpHKYW986Dax9-ZIM7I-HyBoWVcLPuaQ/view?usp=sharing)
+* VideoMatte240K (We are still dealing with licensing. In the meantime, you can visit [storyblocks.com](https://www.storyblocks.com/video/search/green+screen+human?max_duration=10000&sort=most_relevant&video_quality=HD) to download raw green screen videos and recreate the dataset yourself.)
+
@@ -85,7 +87,7 @@ You can run our model using **PyTorch**, **TorchScript**, **TensorFlow**, and **
 
 ## Training
 
-Training code will be released upon acceptance of the paper.
+Configure `data_path.py` to point to your dataset. The original paper uses `train_base.py` to train only the base model to convergence, then uses `train_refine.py` to train the entire network end-to-end. More details are specified in the paper.
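For reference, the example invocations from the two new scripts' docstrings chain the stages together (the refine checkpoint path is a placeholder from the source):

CUDA_VISIBLE_DEVICES=0 python train_base.py \
    --dataset-name videomatte240k \
    --model-backbone resnet50 \
    --model-name mattingbase-resnet50-videomatte240k \
    --model-pretrain-initialization "pretraining/best_deeplabv3_resnet50_voc_os16.pth" \
    --epoch-end 8

CUDA_VISIBLE_DEVICES=0,1 python train_refine.py \
    --dataset-name videomatte240k \
    --model-backbone resnet50 \
    --model-name mattingrefine-resnet50-videomatte240k \
    --model-last-checkpoint "PATH_TO_LAST_CHECKPOINT" \
    --epoch-end 1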
68 changes: 68 additions & 0 deletions data_path.py
@@ -0,0 +1,68 @@
"""
This file records the directory paths to the different datasets.
You will need to configure it for training the model.
All datasets follows the following format, where fgr and pha points to directory that contains jpg or png.
Inside the directory could be any nested formats, but fgr and pha structure must match. You can add your own
dataset to the list as long as it follows the format. 'fgr' should point to foreground images with RGB channels,
'pha' should point to alpha images with only 1 grey channel.
{
'YOUR_DATASET': {
'train': {
'fgr': 'PATH_TO_IMAGES_DIR',
'pha': 'PATH_TO_IMAGES_DIR',
},
'valid': {
'fgr': 'PATH_TO_IMAGES_DIR',
'pha': 'PATH_TO_IMAGES_DIR',
}
}
}
"""

DATA_PATH = {
    'videomatte240k': {
        'train': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR'
        },
        'valid': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR'
        }
    },
    'photomatte13k': {
        'train': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR'
        },
        'valid': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR'
        }
    },
    'distinction': {
        'train': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR',
        },
        'valid': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR'
        },
    },
    'adobe': {
        'train': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR',
        },
        'valid': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR'
        },
    },
    'backgrounds': {
        'train': 'PATH_TO_IMAGES_DIR',
        'valid': 'PATH_TO_IMAGES_DIR'
    },
}
13 changes: 0 additions & 13 deletions dataset/augmentation.py
@@ -52,19 +52,6 @@ def __call__(self, *x):
         return [F.affine(xi, *param, resamples[i], self.fillcolor) for i, xi in enumerate(x)]
 
 
-class PairRandomResizedCrop(T.RandomResizedCrop):
-    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolations=None):
-        super().__init__(size, scale, ratio, Image.BILINEAR)
-        self.interpolations = interpolations
-
-    def __call__(self, *x):
-        if not len(x):
-            return []
-        i, j, h, w = self.get_params(x[0], self.scale, self.ratio)
-        interpolations = self.interpolations or [self.interpolation] * len(x)
-        return [F.resized_crop(xi, i, j, h, w, self.size, interpolations[i]) for i, xi in enumerate(x)]
-
-
 class PairRandomHorizontalFlip(T.RandomHorizontalFlip):
     def __call__(self, *x):
         if torch.rand(1) < self.p:
265 changes: 265 additions & 0 deletions train_base.py
@@ -0,0 +1,265 @@
"""
Train MattingBase
You can download pretrained DeepLabV3 weights from <https://github.com/VainF/DeepLabV3Plus-Pytorch>
Example:
CUDA_VISIBLE_DEVICES=0 python train_base.py \
--dataset-name videomatte240k \
--model-backbone resnet50 \
--model-name mattingbase-resnet50-videomatte240k \
--model-pretrain-initialization "pretraining/best_deeplabv3_resnet50_voc_os16.pth" \
--epoch-end 8
"""

import argparse
import kornia
import torch
import os
import random

from torch import nn
from torch.nn import functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.optim import Adam
from torchvision.utils import make_grid
from tqdm import tqdm
from torchvision import transforms as T
from PIL import Image

from data_path import DATA_PATH
from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset
from dataset import augmentation as A
from model import MattingBase
from model.utils import load_matched_state_dict


# --------------- Arguments ---------------


parser = argparse.ArgumentParser()

parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())

parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-name', type=str, required=True)
parser.add_argument('--model-pretrain-initialization', type=str, default=None)
parser.add_argument('--model-last-checkpoint', type=str, default=None)

parser.add_argument('--batch-size', type=int, default=8)
parser.add_argument('--num-workers', type=int, default=16)
parser.add_argument('--epoch-start', type=int, default=0)
parser.add_argument('--epoch-end', type=int, required=True)

parser.add_argument('--log-train-loss-interval', type=int, default=10)
parser.add_argument('--log-train-images-interval', type=int, default=2000)
parser.add_argument('--log-valid-interval', type=int, default=5000)

parser.add_argument('--checkpoint-interval', type=int, default=5000)

args = parser.parse_args()


# --------------- Loading ---------------


def train():

    # Training DataLoader
    dataset_train = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.4, 1), shear=(-5, 5)),
            A.PairRandomHorizontalFlip(),
            A.PairRandomBoxBlur(0.1, 5),
            A.PairRandomSharpen(0.1),
            A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
            T.RandomHorizontalFlip(),
            A.RandomBoxBlur(0.1, 5),
            A.RandomSharpen(0.1),
            T.ColorJitter(0.15, 0.15, 0.15, 0.05),
            T.ToTensor()
        ])),
    ])
    dataloader_train = DataLoader(dataset_train,
                                  shuffle=True,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    # Validation DataLoader
    dataset_valid = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
            T.ToTensor()
        ])),
    ])
    dataset_valid = SampleDataset(dataset_valid, 50)
    dataloader_valid = DataLoader(dataset_valid,
                                  pin_memory=True,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers)

    # Model
    model = MattingBase(args.model_backbone).cuda()

    if args.model_last_checkpoint is not None:
        load_matched_state_dict(model, torch.load(args.model_last_checkpoint))
    elif args.model_pretrain_initialization is not None:
        model.load_pretrained_deeplabv3_state_dict(torch.load(args.model_pretrain_initialization)['model_state'])

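    # Per-module learning rates: the pretrained backbone is updated more
    # gently (1e-4) than the ASPP module and decoder (5e-4).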
    optimizer = Adam([
        {'params': model.backbone.parameters(), 'lr': 1e-4},
        {'params': model.aspp.parameters(), 'lr': 5e-4},
        {'params': model.decoder.parameters(), 'lr': 5e-4}
    ])
    scaler = GradScaler()

    # Logging and checkpoints
    if not os.path.exists(f'checkpoint/{args.model_name}'):
        os.makedirs(f'checkpoint/{args.model_name}')
    writer = SummaryWriter(f'log/{args.model_name}')

    # Run loop
    for epoch in range(args.epoch_start, args.epoch_end):
        for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
            step = epoch * len(dataloader_train) + i

            true_pha = true_pha.cuda(non_blocking=True)
            true_fgr = true_fgr.cuda(non_blocking=True)
            true_bgr = true_bgr.cuda(non_blocking=True)
            true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)

            true_src = true_bgr.clone()

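            # The shadow augmentation below darkens the plain background with a
            # randomly warped, blurred copy of the alpha matte before the
            # foreground is composited, roughly simulating a shadow cast by the
            # subject.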
            # Augment with shadow
            aug_shadow_idx = torch.rand(len(true_src)) < 0.3
            if aug_shadow_idx.any():
                aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
                aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
                aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
                true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)
                del aug_shadow
            del aug_shadow_idx

            # Composite foreground onto source
            true_src = true_fgr * true_pha + true_src * (1 - true_pha)

            # Augment with noise
            aug_noise_idx = torch.rand(len(true_src)) < 0.4
            if aug_noise_idx.any():
                true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
                true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
            del aug_noise_idx

            # Augment background with jitter
            aug_jitter_idx = torch.rand(len(true_src)) < 0.8
            if aug_jitter_idx.any():
                true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
            del aug_jitter_idx

            # Augment background with affine
            aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
            if aug_affine_idx.any():
                true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
            del aug_affine_idx

            with autocast():
                pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
                loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if (i + 1) % args.log_train_loss_interval == 0:
                writer.add_scalar('loss', loss, step)

            if (i + 1) % args.log_train_images_interval == 0:
                writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
                writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
                writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
                writer.add_image('train_pred_err', make_grid(pred_err, nrow=5), step)
                writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)
                writer.add_image('train_true_bgr', make_grid(true_bgr, nrow=5), step)

            del true_pha, true_fgr, true_bgr
            del pred_pha, pred_fgr, pred_err

            if (i + 1) % args.log_valid_interval == 0:
                valid(model, dataloader_valid, writer, step)

            if (step + 1) % args.checkpoint_interval == 0:
                torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')

        torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')


# --------------- Utils ---------------


def compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr):
    true_err = torch.abs(pred_pha.detach() - true_pha)
    true_msk = true_pha != 0
    return F.l1_loss(pred_pha, true_pha) + \
           F.l1_loss(kornia.sobel(pred_pha), kornia.sobel(true_pha)) + \
           F.l1_loss(pred_fgr * true_msk, true_fgr * true_msk) + \
           F.mse_loss(pred_err, true_err)
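# Written out, the base loss reads directly off the code above:
#   L = ||pha_hat - pha||_1
#     + ||sobel(pha_hat) - sobel(pha)||_1
#     + ||(fgr_hat - fgr) * (pha > 0)||_1
#     + MSE(err_hat, |pha_hat - pha|)
# i.e. L1 on alpha and its edges, L1 on the foreground only where the true
# alpha is nonzero, plus an MSE term that teaches the error map to predict
# the (detached) alpha residual.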


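# random_crop picks one random target size per batch, resizes every tensor to
# cover it, and center-crops, so pha, fgr, and bgr stay spatially aligned.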
def random_crop(*imgs):
    w = random.choice(range(256, 512))
    h = random.choice(range(256, 512))
    results = []
    for img in imgs:
        img = kornia.resize(img, (max(h, w), max(h, w)))
        img = kornia.center_crop(img, (h, w))
        results.append(img)
    return results


def valid(model, dataloader, writer, step):
    model.eval()
    loss_total = 0
    loss_count = 0
    with torch.no_grad():
        for (true_pha, true_fgr), true_bgr in dataloader:
            batch_size = true_pha.size(0)

            true_pha = true_pha.cuda(non_blocking=True)
            true_fgr = true_fgr.cuda(non_blocking=True)
            true_bgr = true_bgr.cuda(non_blocking=True)
            true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr

            pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
            loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr)
            loss_total += loss.cpu().item() * batch_size
            loss_count += batch_size

    writer.add_scalar('valid_loss', loss_total / loss_count, step)
    model.train()


# --------------- Start ---------------


if __name__ == '__main__':
    train()
309 changes: 309 additions & 0 deletions train_refine.py
@@ -0,0 +1,309 @@
"""
Train MattingRefine
Supports multi-GPU training with DistributedDataParallel() and SyncBatchNorm.
Select GPUs through CUDA_VISIBLE_DEVICES environment variable.
Example:
CUDA_VISIBLE_DEVICES=0,1 python train_refine.py \
--dataset-name videomatte240k \
--model-backbone resnet50 \
--model-name mattingrefine-resnet50-videomatte240k \
--model-last-checkpoint "PATH_TO_LAST_CHECKPOINT" \
--epoch-end 1
"""

import argparse
import kornia
import torch
import os
import random

from torch import nn
from torch import distributed as dist
from torch import multiprocessing as mp
from torch.nn import functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, Subset
from torch.optim import Adam
from torchvision.utils import make_grid
from tqdm import tqdm
from torchvision import transforms as T
from PIL import Image

from data_path import DATA_PATH
from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset
from dataset import augmentation as A
from model import MattingRefine
from model.utils import load_matched_state_dict


# --------------- Arguments ---------------


parser = argparse.ArgumentParser()

parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())

parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-thresholding', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3, choices=[1, 3])
parser.add_argument('--model-name', type=str, required=True)
parser.add_argument('--model-last-checkpoint', type=str, default=None)

parser.add_argument('--batch-size', type=int, default=4)
parser.add_argument('--num-workers', type=int, default=16)
parser.add_argument('--epoch-start', type=int, default=0)
parser.add_argument('--epoch-end', type=int, required=True)

parser.add_argument('--log-train-loss-interval', type=int, default=10)
parser.add_argument('--log-train-images-interval', type=int, default=1000)
parser.add_argument('--log-valid-interval', type=int, default=2000)

parser.add_argument('--checkpoint-interval', type=int, default=2000)

args = parser.parse_args()


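# One process is spawned per visible GPU; the global batch size must divide
# evenly, since each process receives batch_size // distributed_num_gpus.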
distributed_num_gpus = torch.cuda.device_count()
assert args.batch_size % distributed_num_gpus == 0


# --------------- Main ---------------

def train_worker(rank, addr, port):

    # Distributed Setup
    os.environ['MASTER_ADDR'] = addr
    os.environ['MASTER_PORT'] = port
    dist.init_process_group("nccl", rank=rank, world_size=distributed_num_gpus)

    # Training DataLoader
    dataset_train = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
            A.PairRandomHorizontalFlip(),
            A.PairRandomBoxBlur(0.1, 5),
            A.PairRandomSharpen(0.1),
            A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
            T.RandomHorizontalFlip(),
            A.RandomBoxBlur(0.1, 5),
            A.RandomSharpen(0.1),
            T.ColorJitter(0.15, 0.15, 0.15, 0.05),
            T.ToTensor()
        ])),
    ])
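    # Shard the training set manually: each rank takes a disjoint contiguous
    # slice via Subset (rather than a DistributedSampler), so every process
    # trains on its own portion of the data.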
    dataset_train_len_per_gpu_worker = int(len(dataset_train) / distributed_num_gpus)
    dataset_train = Subset(dataset_train, range(rank * dataset_train_len_per_gpu_worker, (rank + 1) * dataset_train_len_per_gpu_worker))
    dataloader_train = DataLoader(dataset_train,
                                  shuffle=True,
                                  pin_memory=True,
                                  drop_last=True,
                                  batch_size=args.batch_size // distributed_num_gpus,
                                  num_workers=args.num_workers // distributed_num_gpus)

    # Validation DataLoader
    if rank == 0:
        dataset_valid = ZipDataset([
            ZipDataset([
                ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
                ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
            ], transforms=A.PairCompose([
                A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
                A.PairApply(T.ToTensor())
            ]), assert_equal_length=True),
            ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
                A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
                T.ToTensor()
            ])),
        ])
        dataset_valid = SampleDataset(dataset_valid, 50)
        dataloader_valid = DataLoader(dataset_valid,
                                      pin_memory=True,
                                      drop_last=True,
                                      batch_size=args.batch_size // distributed_num_gpus,
                                      num_workers=args.num_workers // distributed_num_gpus)

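    # Wrap the model below for multi-GPU training: BatchNorm layers are
    # converted to SyncBatchNorm so statistics are synchronized across
    # processes, and gradients are averaged by DistributedDataParallel.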
    # Model
    model = MattingRefine(args.model_backbone,
                          args.model_backbone_scale,
                          args.model_refine_mode,
                          args.model_refine_sample_pixels,
                          args.model_refine_thresholding,
                          args.model_refine_kernel_size).to(rank)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model_distributed = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    if args.model_last_checkpoint is not None:
        load_matched_state_dict(model, torch.load(args.model_last_checkpoint))

    optimizer = Adam([
        {'params': model.backbone.parameters(), 'lr': 5e-5},
        {'params': model.aspp.parameters(), 'lr': 5e-5},
        {'params': model.decoder.parameters(), 'lr': 1e-4},
        {'params': model.refiner.parameters(), 'lr': 3e-4},
    ])
    scaler = GradScaler()

    # Logging and checkpoints
    if rank == 0:
        if not os.path.exists(f'checkpoint/{args.model_name}'):
            os.makedirs(f'checkpoint/{args.model_name}')
        writer = SummaryWriter(f'log/{args.model_name}')

    # Run loop
    for epoch in range(args.epoch_start, args.epoch_end):
        for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
            step = epoch * len(dataloader_train) + i

            true_pha = true_pha.to(rank, non_blocking=True)
            true_fgr = true_fgr.to(rank, non_blocking=True)
            true_bgr = true_bgr.to(rank, non_blocking=True)
            true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)

            true_src = true_bgr.clone()

            # Augment with shadow
            aug_shadow_idx = torch.rand(len(true_src)) < 0.3
            if aug_shadow_idx.any():
                aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
                aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
                aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
                true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)
                del aug_shadow
            del aug_shadow_idx

            # Composite foreground onto source
            true_src = true_fgr * true_pha + true_src * (1 - true_pha)

            # Augment with noise
            aug_noise_idx = torch.rand(len(true_src)) < 0.4
            if aug_noise_idx.any():
                true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
                true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
            del aug_noise_idx

            # Augment background with jitter
            aug_jitter_idx = torch.rand(len(true_src)) < 0.8
            if aug_jitter_idx.any():
                true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
            del aug_jitter_idx

            # Augment background with affine
            aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
            if aug_affine_idx.any():
                true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
            del aug_affine_idx

            with autocast():
                pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model_distributed(true_src, true_bgr)
                loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if rank == 0:
                if (i + 1) % args.log_train_loss_interval == 0:
                    writer.add_scalar('loss', loss, step)

                if (i + 1) % args.log_train_images_interval == 0:
                    writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
                    writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
                    writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
                    writer.add_image('train_pred_err', make_grid(pred_err_sm, nrow=5), step)
                    writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)

                del true_pha, true_fgr, true_src, true_bgr
                del pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm

                if (i + 1) % args.log_valid_interval == 0:
                    valid(model, dataloader_valid, writer, step)

                if (step + 1) % args.checkpoint_interval == 0:
                    torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')

        if rank == 0:
            torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')

    # Clean up
    dist.destroy_process_group()


# --------------- Utils ---------------


def compute_loss(pred_pha_lg, pred_fgr_lg, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha_lg, true_fgr_lg):
    true_pha_sm = kornia.resize(true_pha_lg, pred_pha_sm.shape[2:])
    true_fgr_sm = kornia.resize(true_fgr_lg, pred_fgr_sm.shape[2:])
    true_msk_lg = true_pha_lg != 0
    true_msk_sm = true_pha_sm != 0
    return F.l1_loss(pred_pha_lg, true_pha_lg) + \
           F.l1_loss(pred_pha_sm, true_pha_sm) + \
           F.l1_loss(kornia.sobel(pred_pha_lg), kornia.sobel(true_pha_lg)) + \
           F.l1_loss(kornia.sobel(pred_pha_sm), kornia.sobel(true_pha_sm)) + \
           F.l1_loss(pred_fgr_lg * true_msk_lg, true_fgr_lg * true_msk_lg) + \
           F.l1_loss(pred_fgr_sm * true_msk_sm, true_fgr_sm * true_msk_sm) + \
           F.mse_loss(kornia.resize(pred_err_sm, true_pha_lg.shape[2:]),
                      kornia.resize(pred_pha_sm, true_pha_lg.shape[2:]).sub(true_pha_lg).abs())
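# The refine loss mirrors the base loss at two scales: the alpha, alpha-edge,
# and masked-foreground L1 terms are applied to both the full-resolution
# ('lg') outputs and the coarse ('sm') outputs, and the coarse error map is
# upsampled and regressed (MSE) onto the upsampled coarse alpha's absolute
# residual at full resolution.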


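# Note: the // 4 * 4 rounding below keeps crop dimensions divisible by 4,
# presumably so the quarter-resolution base pass (default
# --model-backbone-scale 0.25) divides evenly.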
def random_crop(*imgs):
    H_src, W_src = imgs[0].shape[2:]
    W_tgt = random.choice(range(1024, 2048)) // 4 * 4
    H_tgt = random.choice(range(1024, 2048)) // 4 * 4
    scale = max(W_tgt / W_src, H_tgt / H_src)
    results = []
    for img in imgs:
        img = kornia.resize(img, (int(H_src * scale), int(W_src * scale)))
        img = kornia.center_crop(img, (H_tgt, W_tgt))
        results.append(img)
    return results


def valid(model, dataloader, writer, step):
    model.eval()
    loss_total = 0
    loss_count = 0
    with torch.no_grad():
        for (true_pha, true_fgr), true_bgr in dataloader:
            batch_size = true_pha.size(0)

            true_pha = true_pha.cuda(non_blocking=True)
            true_fgr = true_fgr.cuda(non_blocking=True)
            true_bgr = true_bgr.cuda(non_blocking=True)
            true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr

            pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model(true_src, true_bgr)
            loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)
            loss_total += loss.cpu().item() * batch_size
            loss_count += batch_size

    writer.add_scalar('valid_loss', loss_total / loss_count, step)
    model.train()


# --------------- Start ---------------


if __name__ == '__main__':
    addr = 'localhost'
    port = str(random.choice(range(12300, 12400)))  # pick a random port.
    mp.spawn(train_worker,
             nprocs=distributed_num_gpus,
             args=(addr, port),
             join=True)
