Skip to content

Commit

Permalink
add denoiser only
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-zqwang committed Feb 8, 2025
1 parent e2298b6 commit b8779f6
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 11 deletions.
69 changes: 67 additions & 2 deletions puzzlefusion_plusplus/auto_aggl.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def __init__(self, cfg):
super(AutoAgglomerative, self).__init__()
self.cfg = cfg
self.denoiser = DenoiserTransformer(cfg.denoiser)
self.verifier = VerifierTransformer(cfg.verifier)
if cfg.verifier.max_iters > 1:
self.verifier = VerifierTransformer(cfg.verifier)
self.encoder = VQVAE(cfg.ae)

self.save_hyperparameters()
Expand Down Expand Up @@ -91,8 +92,72 @@ def _extract_features(self, part_pcs, part_valids, noisy_trans_and_rots):
xyz[part_valids.bool()] = encoder_out["xyz"]
return latent, xyz


def test_denoiser_only(self, data_dict):
gt_trans = data_dict['part_trans']
gt_rots = data_dict['part_rots']
gt_trans_and_rots = torch.cat([gt_trans, gt_rots], dim=-1)
noisy_trans_and_rots = torch.randn(gt_trans_and_rots.shape, device=self.device)
ref_part = data_dict["ref_part"]

reference_gt_and_rots = torch.zeros_like(gt_trans_and_rots, device=self.device)
reference_gt_and_rots[ref_part] = gt_trans_and_rots[ref_part]

noisy_trans_and_rots[ref_part] = reference_gt_and_rots[ref_part]

part_valids = data_dict['part_valids'].clone()
part_scale = data_dict["part_scale"].clone()
part_pcs = data_dict["part_pcs"].clone()

all_pred_trans_rots = []
for t in self.noise_scheduler.timesteps:
timesteps = t.reshape(-1).repeat(len(noisy_trans_and_rots)).cuda()
latent, xyz = self._extract_features(part_pcs, part_valids, noisy_trans_and_rots)
pred_noise = self.denoiser(
noisy_trans_and_rots,
timesteps,
latent,
xyz,
part_valids,
part_scale,
ref_part
)
noisy_trans_and_rots = self.noise_scheduler.step(pred_noise, t, noisy_trans_and_rots).prev_sample
noisy_trans_and_rots[ref_part] = reference_gt_and_rots[ref_part]
all_pred_trans_rots.append(noisy_trans_and_rots.detach().cpu().numpy())

pts = data_dict['part_pcs']
pred_trans = noisy_trans_and_rots[..., :3]
pred_rots = noisy_trans_and_rots[..., 3:]

expanded_part_scale = data_dict["part_scale"].unsqueeze(-1).expand(-1, -1, 1000, -1)
pts = pts * expanded_part_scale

def test_step(self, data_dict, idx):
acc, _, _ = calc_part_acc(pts, trans1=pred_trans, trans2=gt_trans,
rot1=pred_rots, rot2=gt_rots, valids=data_dict['part_valids'],
chamfer_distance=self.metric)

shape_cd = calc_shape_cd(pts, trans1=pred_trans, trans2=gt_trans,
rot1=pred_rots, rot2=gt_rots, valids=data_dict['part_valids'],
chamfer_distance=self.metric)

rmse_r = rot_metrics(pred_rots, gt_rots, data_dict['part_valids'], 'rmse')
rmse_t = trans_metrics(pred_trans, gt_trans, data_dict['part_valids'], 'rmse')


self.acc_list.append(acc)
self.rmse_r_list.append(rmse_r)
self.rmse_t_list.append(rmse_t)
self.cd_list.append(shape_cd)

self._save_inference_data(data_dict, np.stack(all_pred_trans_rots, axis=0), acc)


def test_step(self, data_dict, idx):
if self.cfg.verifier.max_iters == 1:
self.test_denoiser_only(data_dict)
return

gt_trans = data_dict['part_trans']
gt_rots = data_dict['part_rots']
gt_trans_and_rots = torch.cat([gt_trans, gt_rots], dim=-1)
Expand Down
11 changes: 7 additions & 4 deletions puzzlefusion_plusplus/denoiser/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@ def __init__(
cfg,
data_dir,
overfit,
data_fn
data_fn,
denoiser_only_flag = False
):
self.cfg = cfg
self.mode = data_fn
self.data_dir = data_dir
self.data_files = sorted([f for f in os.listdir(self.data_dir) if f.endswith('.npz')])
self.noise_scheduler = PiecewiseScheduler()
self.max_num_part = self.cfg.data.max_num_part
self.denoiser_only_flag = denoiser_only_flag

if overfit != -1:
self.data_files = self.data_files[:overfit]
Expand Down Expand Up @@ -52,7 +54,7 @@ def __init__(
'graph': graph,
}

if self.mode == "test":
if self.mode == "test" and denoiser_only_flag is False:
matching_data_path = os.path.join(self.matching_data_path, str(data_id) + '.npz')
if not os.path.exists(matching_data_path):
continue
Expand Down Expand Up @@ -188,7 +190,7 @@ def __getitem__(self, idx):
part_pcs_gt = self._pad_data(np.stack(part_pcs_gt, axis=0)).astype(np.float32) # [P, N, 3]


if self.mode == 'test':
if self.mode == 'test' and self.denoiser_only_flag is False:
gt_pc_by_area = self._anchor_coords(
data_dict['gt_pc_by_area'],
pose_gt_t,
Expand Down Expand Up @@ -309,12 +311,13 @@ def build_geometry_dataloader(cfg):
return train_loader, val_loader


def build_test_dataloader(cfg):
def build_test_dataloader(cfg, denoiser_only_flag):
data_dict = dict(
cfg=cfg,
data_dir=cfg.data.data_val_dir,
overfit=cfg.data.overfit,
data_fn="test",
denoiser_only_flag=denoiser_only_flag,
)

val_set = GeometryLatentDataset(**data_dict)
Expand Down
7 changes: 7 additions & 0 deletions scripts/inference_denoiser_only.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
python test.py \
experiment_name=everyday_epoch2000_bs64 \
denoiser.data.val_batch_size=20 \
denoiser.data.data_val_dir=./data/pc_data/everyday/val/ \
denoiser.ckpt_path=output/denoiser/everyday_epoch2000_bs64/training/last.ckpt \
inference_dir=denoiser_only \
verifier.max_iters=1 \
2 changes: 1 addition & 1 deletion scripts/render.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
python renderer/render_results.py \
experiment_name=everyday_epoch2000_gpu4_bs64 \
inference_dir=code_clean_results_iter3 \
inference_dir=denoiser_only \
renderer.num_samples=20 \
renderer.output_path=results \
renderer.blender.imgRes_x=512 \
Expand Down
12 changes: 8 additions & 4 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ def main(cfg):
inference_dir = os.path.join(cfg.experiment_output_path, "inference", cfg.inference_dir)
os.makedirs(inference_dir, exist_ok=True)

denoiser_only_flag = cfg.verifier.max_iters == 1

# initialize data
test_loader = build_test_dataloader(cfg.denoiser)
test_loader = build_test_dataloader(cfg.denoiser, denoiser_only_flag)

# load denoiser weights
model = AutoAgglomerative(cfg)
Expand All @@ -33,9 +35,11 @@ def main(cfg):
if k.startswith('encoder.')}
)

# load verifier weights
verifier_weights = torch.load(cfg.verifier.ckpt_path)['state_dict']
model.verifier.load_state_dict({k.replace('verifier.', ''): v for k, v in verifier_weights.items()})
if cfg.verifier.max_iters > 1:
# load verifier weights
verifier_weights = torch.load(cfg.verifier.ckpt_path)['state_dict']
model.verifier.load_state_dict({k.replace('verifier.', ''): v for k, v in verifier_weights.items()})

# initialize trainer
trainer = pl.Trainer(accelerator=cfg.accelerator, max_epochs=1, logger=False)

Expand Down

0 comments on commit b8779f6

Please sign in to comment.