-
Notifications
You must be signed in to change notification settings - Fork 28
/
train.py
319 lines (266 loc) · 14 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
import os
from opt import get_opts
import torch
import numpy as np
from collections import defaultdict
from torch.utils.data import DataLoader
from datasets import dataset_dict
# models
from models.nerf import PosEmbedding, NeRF
from models.rendering import render_rays
# optimizer, scheduler, visualization
from utils import *
from torchvision.utils import make_grid
# losses
from losses import loss_dict
# metrics
import metrics
import third_party.lpips.lpips.lpips as lpips
# pytorch-lightning
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TestTubeLogger
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning import seed_everything
seed_everything(42, workers=True)
class NSFFSystem(LightningModule):
    """Lightning training system built around coarse/fine NeRF models.

    Wires together positional embeddings, optional per-frame appearance
    ('a') and transient ('t') embeddings, the 'nerfw' loss taken from
    ``loss_dict``, the datasets taken from ``dataset_dict`` and the
    optimizer/scheduler helpers, and implements the train/validation loops.
    """

    def __init__(self, hparams):
        """Build loss, embeddings and NeRF models from the hyper-parameters.

        Args:
            hparams: hyper-parameter container; stored through
                ``save_hyperparameters`` and afterwards read via
                ``self.hparams``.
        """
        super().__init__()
        self.save_hyperparameters(hparams)

        # losses and metrics
        self.loss = \
            loss_dict['nerfw'](lambda_geo=self.hparams.lambda_geo_init,
                               thickness=self.hparams.thickness,
                               topk=self.hparams.topk)

        # models
        self.embedding_xyz = PosEmbedding(hparams.S_emb_xyz, hparams.N_emb_xyz)
        self.embedding_dir = PosEmbedding(hparams.S_emb_dir, hparams.N_emb_dir)
        self.embeddings = {'xyz': self.embedding_xyz, 'dir': self.embedding_dir}

        # number of frames covered by the [start, end) range
        self.N_frames = self.hparams.start_end[1]-self.hparams.start_end[0]

        if hparams.encode_a:
            # one learned appearance vector per frame
            self.embedding_a = torch.nn.Embedding(self.N_frames, hparams.N_a)
            self.embeddings['a'] = self.embedding_a
            load_ckpt(self.embedding_a, hparams.weight_path, 'embedding_a')

        if hparams.encode_t:
            # one learned transient vector per frame
            self.embedding_t = torch.nn.Embedding(self.N_frames, hparams.N_tau)
            self.embeddings['t'] = self.embedding_t
            load_ckpt(self.embedding_t, hparams.weight_path, 'embedding_t')

        self.output_transient = hparams.encode_t
        self.output_transient_flow = ['fw', 'bw', 'disocc'] if hparams.encode_t else []

        # fine model always exists
        self.nerf_fine = NeRF(typ='fine',
                              in_channels_xyz=6*hparams.N_emb_xyz+3,
                              use_viewdir=hparams.use_viewdir,
                              in_channels_dir=6*hparams.N_emb_dir+3,
                              encode_appearance=hparams.encode_a,
                              in_channels_a=hparams.N_a,
                              encode_transient=hparams.encode_t,
                              in_channels_t=hparams.N_tau,
                              output_flow=len(self.output_transient_flow)>0,
                              flow_scale=hparams.flow_scale)
        self.models = {'fine': self.nerf_fine}
        load_ckpt(self.nerf_fine, hparams.weight_path,
                  'nerf_fine', hparams.prefixes_to_ignore)

        if hparams.N_importance > 0: # coarse to fine
            self.nerf_coarse = NeRF(typ='coarse',
                                    in_channels_xyz=6*hparams.N_emb_xyz+3,
                                    use_viewdir=hparams.use_viewdir,
                                    in_channels_dir=6*hparams.N_emb_dir+3,
                                    encode_transient=hparams.encode_t,
                                    in_channels_t=hparams.N_tau)
            self.models['coarse'] = self.nerf_coarse
            load_ckpt(self.nerf_coarse, hparams.weight_path,
                      'nerf_coarse', hparams.prefixes_to_ignore)

        # everything handed to get_optimizer in configure_optimizers
        self.models_to_train = [self.models]
        if hparams.encode_a: self.models_to_train += [self.embedding_a]
        if hparams.encode_t: self.models_to_train += [self.embedding_t]

    def get_progress_bar_dict(self):
        """Return the default progress-bar items without the ``v_num`` entry."""
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)
        return items

    def forward(self, rays, ts, test_time=False, **kwargs):
        """Do batched inference on rays using chunk.

        Args:
            rays: ray tensor; it is sliced along dim 0 in steps of
                ``hparams.chunk``.
            ts: per-ray time indices sliced the same way, or None.
            test_time: when True, perturbation and noise are set to 0, the
                inner chunk size passed to ``render_rays`` is divided by 4
                and every output is moved to CPU.
            **kwargs: forwarded unchanged to ``render_rays``.

        Returns:
            dict mapping each key returned by ``render_rays`` to the
            concatenation (dim 0) of the per-chunk tensors.
        """
        B = rays.shape[0]
        chunk = self.hparams.chunk
        results = defaultdict(list)
        for i in range(0, B, chunk):
            rendered_ray_chunks = \
                render_rays(self.models,
                            self.embeddings,
                            rays[i:i+chunk],
                            None if ts is None else ts[i:i+chunk],
                            self.train_dataset.N_frames-1,
                            self.hparams.N_samples,
                            self.hparams.perturb if not test_time else 0,
                            self.hparams.noise_std if not test_time else 0,
                            self.hparams.N_importance,
                            chunk//4 if test_time else chunk,
                            **kwargs)

            for k, v in rendered_ray_chunks.items():
                if test_time: v = v.cpu()
                results[k] += [v]

        for k, v in results.items(): results[k] = torch.cat(v, 0)
        return results

    def setup(self, stage):
        """Create the train/val datasets and the buffers that depend on them."""
        dataset = dataset_dict[self.hparams.dataset_name]
        kwargs = {'root_dir': self.hparams.root_dir,
                  'img_wh': tuple(self.hparams.img_wh),
                  'start_end': tuple(self.hparams.start_end),
                  'cache_dir': self.hparams.cache_dir,
                  'hard_sampling': self.hparams.hard_sampling}
        self.train_dataset = dataset(split='train', **kwargs)
        self.val_dataset = dataset(split='val', **kwargs)

        if self.output_transient_flow:
            # the loss reads Ks / Ps / max_t from itself when flow is output
            self.loss.register_buffer('Ks', self.train_dataset.Ks)
            self.loss.register_buffer('Ps', self.train_dataset.Ps)
            self.loss.max_t = self.N_frames-1

        if self.hparams.hard_sampling:
            # create buffer to save temporary rgbs
            self.register_buffer('tmp_rgb',
                torch.zeros(self.N_frames, self.hparams.img_wh[1]*self.hparams.img_wh[0], 3))

    def configure_optimizers(self):
        """Return the optimizer alone for a 'const' schedule, otherwise
        ``([optimizer], [scheduler])``."""
        kwargs = {}
        self.optimizer = get_optimizer(self.hparams, self.models_to_train, **kwargs)
        if self.hparams.lr_scheduler == 'const': return self.optimizer

        scheduler = get_scheduler(self.hparams, self.optimizer)
        return [self.optimizer], [scheduler]

    def train_dataloader(self):
        """Return the training loader; batching is delegated to the dataset
        (``batch_size=None`` disables the loader's own batching)."""
        self.train_dataset.batch_size = self.hparams.batch_size
        return DataLoader(self.train_dataset,
                          shuffle=True,
                          num_workers=4,
                          batch_size=None,
                          pin_memory=True)

    def val_dataloader(self):
        """Return the validation loader (no shuffling, no automatic batching)."""
        return DataLoader(self.val_dataset,
                          shuffle=False,
                          num_workers=4,
                          batch_size=None,
                          pin_memory=True)

    # def on_epoch_start(self):
    #     # for evaluation TODO: avoid being saved in ckpt...
    #     if not hasattr(self, 'lpips_model'):
    #         self.lpips_model = lpips.LPIPS(net='alex', spatial=True)

    def on_train_epoch_start(self):
        """Decay both geometry loss weights by 10x every 10 epochs."""
        lambda_geo = self.hparams.lambda_geo_init * 0.1**(self.current_epoch//10)
        self.loss.lambda_geo_d = lambda_geo
        self.loss.lambda_geo_f = lambda_geo

    def training_step(self, batch, batch_nb):
        """Render one batch of rays, compute and log the losses.

        Returns:
            the sum of all terms returned by ``self.loss``.
        """
        rays, rgbs, ts = batch['rays'], batch['rgbs'], batch.get('ts', None)
        kwargs = {'epoch': self.current_epoch,
                  'output_transient': self.output_transient,
                  'output_transient_flow': self.output_transient_flow}
        results = self(rays, ts, **kwargs)

        if self.hparams.hard_sampling:
            # Detach before the in-place write: tmp_rgb is only a cache read
            # back during validation and must not become part of (and keep
            # alive) the autograd graph of every training step.
            self.tmp_rgb[ts, batch['rand_idx']] = results['rgb_fine'].detach()

        loss_d = self.loss(results, batch, **kwargs)
        loss = sum(loss_d.values())

        with torch.no_grad():
            psnr_ = metrics.psnr(results['rgb_fine'], rgbs)

        self.log('lr', get_learning_rate(self.optimizer))
        self.log('train/loss', loss)
        for k, v in loss_d.items(): self.log(f'train/{k}', v, prog_bar=True)
        self.log('train/psnr', psnr_, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_nb):
        """Render one validation image, log visualisations and return metrics.

        Returns:
            dict with 'val_psnr' and 'val_ssim', plus 'val_psnr_mask' and
            'val_ssim_mask' when transient output is enabled and the batch
            carries a mask containing at least one zero entry.
        """
        rays, rgbs, ts = batch['rays'], batch['rgbs'], batch.get('ts', None)
        batch['rgbs'] = rgbs = rgbs.cpu() # (H*W, 3)
        # mask / disp are optional; bind them to None so that later checks
        # never hit an undefined name when the batch lacks them
        mask = batch['mask'].cpu() if 'mask' in batch else None # (H*W)
        disp = batch['disp'].cpu() if 'disp' in batch else None # (H*W)
        kwargs = {'output_transient': self.output_transient,
                  'output_transient_flow': []}
        results = self(rays, ts, test_time=True, **kwargs)

        # compute error metrics
        W, H = self.hparams.img_wh
        img = torch.clip(results['rgb_fine'].view(H, W, 3).cpu(), 0, 1)
        img_ = img.permute(2, 0, 1)
        img_gt = rgbs.view(H, W, 3).cpu()

        rmse_map = ((img_gt-img)**2).mean(-1)**0.5
        rmse_map_blend = blend_images(img_, visualize_depth(-rmse_map), 0.5)

        ssim_map = metrics.ssim(img_gt, img, reduction='none').mean(-1)
        ssim_map_blend = blend_images(img_, visualize_depth(-ssim_map), 0.5)

        # lpips_map = metrics.lpips(self.lpips_model, img_gt, img, reduction='none')
        # lpips_map_blend = blend_images(img_, visualize_depth(-lpips_map), 0.5)

        depth = visualize_depth(results['depth_fine'].view(H, W))

        img_list = [img_gt.permute(2, 0, 1), img_, depth]
        if self.output_transient:
            img_list += [visualize_mask(results['transient_alpha_fine'].view(H, W))]
            img_list += [torch.clip(results['_static_rgb_fine'].view(H, W, 3).permute(2, 0, 1).cpu(), 0, 1)]
            img_list += [visualize_depth(results['_static_depth_fine'].view(H, W))]
        if mask is not None: img_list += [visualize_mask(1-mask.view(H, W))]
        if disp is not None: img_list += [visualize_depth(-disp.view(H, W))]

        img_grid = make_grid(img_list, nrow=3) # 3 images per row
        self.logger.experiment.add_image('reconstruction/decomposition', img_grid, self.global_step)
        self.logger.experiment.add_image('error_map/rmse', rmse_map_blend, self.global_step)
        self.logger.experiment.add_image('error_map/ssim', ssim_map_blend, self.global_step)
        # self.logger.experiment.add_image('error_map/lpips', lpips_map_blend, self.global_step)

        log = {'val_psnr': metrics.psnr(results['rgb_fine'], rgbs),
               'val_ssim': ssim_map.mean()}
               # 'val_lpips': lpips_map.mean()}

        if self.output_transient and mask is not None and (mask==0).any():
            log['val_psnr_mask'] = metrics.psnr(results['rgb_fine'], rgbs, mask==0)
            log['val_ssim_mask'] = ssim_map[mask.view(H, W)==0].mean()
            # log['val_lpips_mask'] = lpips_map[mask.view(H, W)==0].mean()

        if self.hparams.hard_sampling:
            # update weights, independent of the above val result (use self.tmp_rgb buffer)
            # high ssim = low weight
            for i in range(self.N_frames):
                img_gt = self.train_dataset.rays_dict[i][:, 6:9].view(H, W, 3)
                img = self.tmp_rgb[i].view(H, W, 3).cpu()
                _ssim_map = metrics.ssim(img_gt, img, reduction='none').mean(-1)
                self.train_dataset.weights[i] = 1-_ssim_map.numpy().flatten()
                if i == self.N_frames//2:
                    # only the middle frame is visualised
                    _ssim_map_blend = blend_images(img.permute(2, 0, 1), visualize_depth(-_ssim_map), 0.5)
                    self.logger.experiment.add_image('misc/moving_ssim',
                                                     _ssim_map_blend, self.global_step)

        return log

    def validation_epoch_end(self, outputs):
        """Average the per-image validation metrics and log them.

        Args:
            outputs: list of the dicts returned by ``validation_step``.
        """
        mean_psnr = torch.stack([x['val_psnr'] for x in outputs]).mean()
        mean_ssim = torch.stack([x['val_ssim'] for x in outputs]).mean()
        # mean_lpips = torch.stack([x['val_lpips'] for x in outputs]).mean()

        self.log('val/psnr', mean_psnr, prog_bar=True)
        self.log('val/ssim', mean_ssim)
        # self.log('val/lpips', mean_lpips)

        # masked metrics are only averaged when every image produced them
        if self.output_transient and all('val_psnr_mask' in x for x in outputs):
            mean_psnr_mask = torch.stack([x['val_psnr_mask'] for x in outputs]).mean()
            mean_ssim_mask = torch.stack([x['val_ssim_mask'] for x in outputs]).mean()
            # mean_lpips_mask = torch.stack([x['val_lpips_mask'] for x in outputs]).mean()
            self.log('val/psnr_mask', mean_psnr_mask, prog_bar=True)
            self.log('val/ssim_mask', mean_ssim_mask)
            # self.log('val/lpips_mask', mean_lpips_mask)
def main(hparams):
    """Assemble the system, checkpoint callback, logger and trainer, then fit.

    Args:
        hparams: parsed options; this function reads exp_name, num_epochs,
            ckpt_path, refresh_every, num_gpus and num_nodes from it.
    """
    system = NSFFSystem(hparams)

    # checkpoints are written under ckpts/<exp_name>, named by epoch number
    checkpoint_callback = ModelCheckpoint(
        dirpath=f'ckpts/{hparams.exp_name}',
        filename='{epoch:d}',
        save_top_k=-1,
    )

    tt_logger = TestTubeLogger(
        save_dir="logs",
        name=hparams.exp_name,
        debug=False,
        create_git_tag=False,
        log_graph=False,
    )

    # 'ddp' is requested only for more than one GPU; the simple profiler
    # only when exactly one GPU is used
    trainer_options = {
        'max_epochs': hparams.num_epochs,
        'callbacks': [checkpoint_callback],
        'resume_from_checkpoint': hparams.ckpt_path,
        'logger': tt_logger,
        'weights_summary': None,
        'progress_bar_refresh_rate': hparams.refresh_every,
        'gpus': hparams.num_gpus,
        'num_nodes': hparams.num_nodes,
        'accelerator': 'ddp' if hparams.num_gpus>1 else None,
        'num_sanity_val_steps': 1,
        'reload_dataloaders_every_epoch': True,
        'benchmark': True,
        'profiler': "simple" if hparams.num_gpus==1 else None,
        'plugins': [DDPPlugin(find_unused_parameters=False)],
    }
    trainer = Trainer(**trainer_options)

    trainer.fit(system)
def backup_files(args, files):
    """Save files for debugging.

    Copies every path in ``files`` into ``files_backup/<args.exp_name>``
    (created if missing, relative to the current working directory).

    Args:
        args: object exposing an ``exp_name`` attribute.
        files: iterable of file paths to copy.

    The copy is best-effort: a file that cannot be copied is reported and
    skipped, so a missing source never aborts the run.
    """
    # local import keeps this block self-contained
    import shutil

    backup_dir = os.path.join('files_backup', args.exp_name)
    os.makedirs(backup_dir, exist_ok=True)
    for f in files:
        try:
            # shutil.copy instead of a shell `cp` string: works with paths
            # containing spaces/metacharacters and on non-POSIX systems
            shutil.copy(f, backup_dir)
        except OSError as err:
            print(f'backup_files: could not copy {f}: {err}')
if __name__ == '__main__':
    # Script entry point: parse options, optionally snapshot the key source
    # files (debug mode only), then start training.
    hparams = get_opts()

    if hparams.debug:
        files_to_backup = [
            'models/nerf.py',
            'models/rendering.py',
            'losses.py',
            'train.py',
        ]
        backup_files(hparams, files_to_backup)

    main(hparams)