add docs and inference code
eric-zqwang committed Jul 1, 2024
1 parent f60de0e commit e30b9dd
Showing 19 changed files with 692 additions and 114 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -6,7 +6,7 @@ tools/
scripts/
data/
.vscode/
puzzlefusion_plusplus/auto_aggl.py


__pycache__
*.ipynb
41 changes: 16 additions & 25 deletions README.md
@@ -2,47 +2,36 @@
<h1 align="center"> PuzzleFusion++: Auto-agglomerative 3D Fracture <br/> Assembly by Denoise and Verify
</h1>



### [Zhengqing Wang*<sup>1</sup>](https://eric-zqwang.com/) , [Jiacheng Chen*<sup>1</sup>](https://jcchen.me) , [Yasutaka Furukawa<sup>1,2</sup>](https://www2.cs.sfu.ca/~furukawa/)

### <sup>1</sup> Simon Fraser University <sup>2</sup> Wayve

### [arXiv](https://arxiv.org/abs/2406.00259), [Project page](https://puzzlefusion-plusplus.github.io/)

</div>



<!-- https://github.com/woodfrog/maptracker/assets/13405255/1c0e072a-cb77-4000-b81b-5b9fd40f8f39 -->




This repository provides the official implementation of the paper [PuzzleFusion++: Auto-agglomerative 3D Fracture Assembly by Denoise and Verify](https://arxiv.org/abs/2406.00259).


## Table of Contents

- [Introduction](#introduction)
- [Model Architecture](#model-architecture)
- [Installation](#installation)
- [Data preparation](#data-preparation)
- [Getting Started](#getting-started)
- [Acknowledgements](#acknowledgements)
- [Getting started](#getting-started)
- [Citation](#citation)
- [License](#license)

## Introduction
This paper proposes a novel “auto-agglomerative” 3D fracture assembly method, PuzzleFusion++, resembling how humans solve challenging spatial puzzles.

Starting from individual fragments, the approach 1) aligns and merges fragments into larger groups akin to agglomerative clustering and 2) repeats the process iteratively in completing the assembly akin to auto-regressive methods. Concretely, a diffusion model denoises the 6-DoF alignment parameters of the fragments simultaneously, and a transformer model verifies and merges pairwise alignments into larger ones, whose process repeats iteratively.

Extensive experiments on the Breaking Bad dataset show that PuzzleFusion++ outperforms all other state-of-the-art techniques by significant margins across all metrics, in particular by over 10% in part accuracy and 50% in Chamfer distance.

This paper proposes a novel “auto-agglomerative” 3D fracture assembly method, PuzzleFusion++, resembling how humans solve challenging spatial puzzles.

## Model Architecture
<div align="center">
<img src="docs/fig/arch.png" width=80% height=80%>
</div>

![visualization](docs/fig/arch.png)
Starting from individual fragments, the approach 1) aligns and merges fragments into larger groups akin to agglomerative clustering and 2) repeats the process iteratively in completing the assembly akin to auto-regressive methods. Concretely, a diffusion model denoises the 6-DoF alignment parameters of the fragments simultaneously (the **Denoiser** in the figure above), and a transformer model verifies and merges pairwise alignments into larger ones (the **Verifier** in the figure above), whose process repeats iteratively.
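
The loop described above can be sketched in plain Python. This is only an illustration of the control flow, not the repository's implementation: `denoise` and `verify` are hypothetical callables standing in for the diffusion denoiser and the transformer verifier, and the default `threshold` and `max_iters` simply mirror the `verifier` values in `config/auto_aggl.yaml`.

```python
from itertools import combinations
from typing import Callable, Dict, FrozenSet, List, Set

Group = FrozenSet[int]  # a rigid group of fragment indices


def auto_agglomerate(
    num_fragments: int,
    denoise: Callable[[List[Group]], Dict[int, object]],
    verify: Callable[[int, int, Dict[int, object]], float],
    threshold: float = 0.9,
    max_iters: int = 6,
) -> List[Group]:
    """Hypothetical control flow: denoise, verify pairs, merge, repeat."""
    groups: List[Group] = [frozenset([i]) for i in range(num_fragments)]
    for _ in range(max_iters):
        if len(groups) == 1:
            break
        poses = denoise(groups)  # one pose estimate per fragment
        # Union-find over fragments, seeded with the current groups.
        parent = list(range(num_fragments))

        def find(x: int) -> int:
            while parent[x] != x:
                parent[x] = parent[parent[x]]
                x = parent[x]
            return x

        for group in groups:
            members = sorted(group)
            for member in members[1:]:
                parent[find(member)] = find(members[0])
        merged_any = False
        for a, b in combinations(range(num_fragments), 2):
            # Only pairs from different groups are candidates for merging.
            if find(a) != find(b) and verify(a, b, poses) >= threshold:
                parent[find(a)] = find(b)
                merged_any = True
        if not merged_any:
            break  # nothing verified, so further iterations cannot change the result
        buckets: Dict[int, Set[int]] = {}
        for i in range(num_fragments):
            buckets.setdefault(find(i), set()).add(i)
        groups = [frozenset(s) for s in buckets.values()]
    return groups


# Toy stand-ins: fragments 0, 1 and 2 fit together; fragment 3 never verifies.
def toy_denoise(groups: List[Group]) -> Dict[int, object]:
    return {i: None for g in groups for i in g}


def toy_verify(a: int, b: int, poses: Dict[int, object]) -> float:
    return 1.0 if (a, b) in {(0, 1), (1, 2)} else 0.0


result = auto_agglomerate(4, toy_denoise, toy_verify)
```

In this toy run the first iteration merges fragments 0, 1 and 2 into one group, and the second iteration stops early because no remaining pair passes the threshold.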

Extensive experiments on the Breaking Bad dataset show that PuzzleFusion++ outperforms all other state-of-the-art techniques by significant margins across all metrics, in particular by over 10% in part accuracy and 50% in Chamfer distance.


## Installation
@@ -52,15 +41,14 @@ Please refer to the [installation guide](docs/installation.md) to set up the env

## Data preparation

Please refer to the [data preparation guide](docs/data_preparation.md) to download and prepare the Breaking Bad dataset and to download our pre-trained model checkpoints.

## Getting Started

## Getting started

<!-- ## Acknowledgements
Please follow the [test guide](docs/test.md) for model inference, evaluation, and visualization.

We're grateful to the open-source projects below, their great work made our project possible:
* [PuzzleFusion](https://github.com/sepidsh/PuzzleFussion)
* [Jigsaw](https://github.com/Jiaxin-Lu/Jigsaw) -->
Please follow the [training guide](docs/training.md) for details about the training pipeline.


## Citation
@@ -76,6 +64,9 @@ If you find PuzzleFusion++ useful in your research or applications, please consi
}
```

Our method is deeply inspired by [PuzzleFusion](https://github.com/sepidsh/PuzzleFussion) and [Jigsaw](https://github.com/Jiaxin-Lu/Jigsaw), and benefits from their open-source code. Please consider reading these papers if you are interested in related topics.


## License

This project is licensed under GPL; see the [license file](LICENSE) for details.
9 changes: 5 additions & 4 deletions config/ae/data.yaml
@@ -1,10 +1,10 @@
data:
batch_size: 6
val_batch_size: 6
batch_size: 64
val_batch_size: 64
num_workers: 6
data_fn: "everyday.{}.txt"
data_dir: ./pc_data/everyday/train/
data_val_dir: ./pc_data/everyday/val/
data_dir: ./data/pc_data/everyday/train/
data_val_dir: ./data/pc_data/everyday/val/
mesh_data_dir: ../Breaking-Bad-Dataset.github.io/data/
rot_range: -1
overfit: -1
@@ -14,3 +14,4 @@ data:
min_num_part: 2
max_num_part: 20
shuffle_parts: False
category: all
72 changes: 71 additions & 1 deletion config/auto_aggl.yaml
@@ -18,15 +18,85 @@ defaults:

denoiser:
ckpt_path: null
data:
val_batch_size: 1
matching_data_path: ./data/matching_data/


verifier:
ckpt_path: null
threshold: 0.9
max_iters: 6

experiment_name: null
train_seed: 123
test_seed: 123
accelerator: gpu

project_root_path: ${hydra:runtime.cwd}
experiment_output_path: ${project_root_path}/output/denoiser/${experiment_name}

inference_dir: null
inference_dir: null


renderer:
output_path: /results/
mesh_path: ../Breaking-Bad-Dataset.github.io/data/
num_samples: 300
duration: 6
extend_endframes: 20
min_parts: 2
max_parts: 20
category: all
material: plastic
save_gt: False
random_sample: True

blender:
imgRes_x: 2048
imgRes_y: 2048
use_GPU: True
exposure: 1.5
numSamples: 200

camera_kwargs:
camera_type: orthographic
fit_camera: False
camPos: [3, 0, 2]
camLookat: [0, 0, 0.5]
camUp: [0, 1, 0]
camHeight: 2.2
resolution: [256, 256]
samples: 32
focalLength: 50

light:
lightAngle: [6, -30, -155]
strength: 2
shadowSoftness: 0.3

render_kwargs:
preview: True
shadow_catcher: False

colors:
- [84, 107, 45]
- [178, 0, 0]
- [135, 206, 234]
- [239, 196, 15]
- [216, 112, 214]
- [255, 127, 79]
- [0, 127, 127]
- [237, 58, 130]
- [196, 237, 0]
- [0, 0, 127]
- [137, 53, 15]
- [112, 127, 142]
- [178, 127, 209]
- [255, 216, 178]
- [127, 127, 0]
- [53, 68, 79]
- [183, 75, 107]
- [70, 72, 107]
- [180, 123, 95]
- [137, 66, 70]
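
As an illustration of how a palette like the one above might be consumed, the sketch below maps a part index to an RGBA tuple with channels in [0, 1]. The helper name `part_color`, the wrap-around behaviour and the float conversion are assumptions made for illustration; the renderer in this commit may assign colors differently. Only the first three palette entries are copied here.

```python
from typing import List, Sequence, Tuple

# First three entries of the `colors` list above (8-bit RGB).
PALETTE: List[Sequence[int]] = [
    [84, 107, 45],
    [178, 0, 0],
    [135, 206, 234],
]


def part_color(
    part_idx: int,
    palette: Sequence[Sequence[int]] = PALETTE,
    alpha: float = 1.0,
) -> Tuple[float, float, float, float]:
    """Return an RGBA tuple in [0, 1] for a part, wrapping around the palette."""
    r, g, b = palette[part_idx % len(palette)]
    return (r / 255.0, g / 255.0, b / 255.0, alpha)
```

With the full 20-entry list and `max_parts: 20`, every part of an object would receive a distinct color and the wrap-around would never trigger.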
4 changes: 2 additions & 2 deletions config/denoiser/data.yaml
@@ -3,8 +3,8 @@ data:
val_batch_size: 64
num_workers: 10
data_fn: "everyday.{}.txt"
data_dir: ./pc_data/everyday/train/
data_val_dir: ./pc_data/everyday/val/
data_dir: ./data/pc_data/everyday/train/
data_val_dir: ./data/pc_data/everyday/val/
rot_range: -1
overfit: -1
min_num_part: 2
2 changes: 1 addition & 1 deletion config/verifier/global_config.yaml
@@ -51,7 +51,7 @@ data:
batch_size: 64
val_batch_size: 64
num_workers: 4
verifier_data_path: null
verifier_data_path: ./data/verifier_data/
overfit: -1


2 changes: 1 addition & 1 deletion config/verifier/model.yaml
@@ -4,4 +4,4 @@ model:
num_bins: 6
embed_dim: 256
num_layers: 6
num_heads: 8
num_heads: 8
53 changes: 53 additions & 0 deletions docs/data_preparation.md
@@ -0,0 +1,53 @@
## Data preparation
We follow the
[Breaking Bad Dataset](https://breaking-bad-dataset.github.io/) for data pre-processing.
For more information about data processing, please refer to the dataset website.

After processing the data, ensure that you have a folder named `data` with the following structure:
```
data
├── breaking_bad
│ ├── everyday
│ │ ├── BeerBottle
│ │ │ ├── ...
│ │ ├── ...
│ ├── everyday.train.txt
│ ├── everyday.val.txt
│ └── ...
└── ...
```
Only the `everyday` subset is necessary.

### Generate point cloud data
The original benchmark code of the Breaking Bad dataset samples a point cloud from each mesh in every batch, which is time-consuming. We instead pre-process the mesh data once and generate the point cloud data and its attributes.
```
cd puzzlefusion-plusplus/
python generate_pc_data.py +data.save_pc_data_path=data/pc_data/everyday/
```
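
To make the idea of sampling a point cloud from a mesh concrete, here is a minimal standard-library sketch. The area-weighted scheme and the names `triangle_area` and `sample_points` are assumptions chosen for illustration; the sampler used by `generate_pc_data.py` may differ in method and output format.

```python
import math
import random
from typing import List, Sequence, Tuple

Point = Tuple[float, float, float]
Face = Tuple[int, int, int]


def triangle_area(a: Point, b: Point, c: Point) -> float:
    """Area of a 3D triangle via the cross product of two edges."""
    ux, uy, uz = b[0] - a[0], b[1] - a[1], b[2] - a[2]
    vx, vy, vz = c[0] - a[0], c[1] - a[1], c[2] - a[2]
    cx, cy, cz = uy * vz - uz * vy, uz * vx - ux * vz, ux * vy - uy * vx
    return 0.5 * math.sqrt(cx * cx + cy * cy + cz * cz)


def sample_points(
    vertices: Sequence[Point],
    faces: Sequence[Face],
    num_points: int,
    seed: int = 0,
) -> List[Point]:
    """Draw points uniformly over the mesh surface (faces weighted by area)."""
    rng = random.Random(seed)
    areas = [triangle_area(*(vertices[i] for i in face)) for face in faces]
    chosen = rng.choices(range(len(faces)), weights=areas, k=num_points)
    points: List[Point] = []
    for face_idx in chosen:
        a, b, c = (vertices[i] for i in faces[face_idx])
        # Uniform barycentric coordinates via the square-root trick.
        r1, r2 = math.sqrt(rng.random()), rng.random()
        w0, w1, w2 = 1.0 - r1, r1 * (1.0 - r2), r1 * r2
        points.append(
            (
                w0 * a[0] + w1 * b[0] + w2 * c[0],
                w0 * a[1] + w1 * b[1] + w2 * c[1],
                w0 * a[2] + w1 * b[2] + w2 * c[2],
            )
        )
    return points


# A unit square in the z = 0 plane, split into two triangles.
square_vertices = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (1.0, 1.0, 0.0), (0.0, 1.0, 0.0)]
square_faces = [(0, 1, 2), (0, 2, 3)]
cloud = sample_points(square_vertices, square_faces, num_points=1000)
```

Doing this once and saving the result avoids repeating the sampling for every batch, which is the motivation given above.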

### Verifier training data
You can download the verifier data from [here](https://1sfu-my.sharepoint.com/:f:/g/personal/zwa170_sfu_ca/EtSHHinoDndPs8kJfRn_n0QBue1ypoXGkNEOio9pU6bFcQ?e=pkcuox).

### Matching data
You can download the matching data from [here](https://1sfu-my.sharepoint.com/:f:/g/personal/zwa170_sfu_ca/EtSHHinoDndPs8kJfRn_n0QBue1ypoXGkNEOio9pU6bFcQ?e=pkcuox).

Both the verifier data and the matching data have to be generated with [Jigsaw](https://github.com/Jiaxin-Lu/Jigsaw). Since this process is quite complex, we provide the processed data for now. More details on how to obtain this processed data will be provided later.

## Checkpoints
We provide the checkpoints at this [link](https://1sfu-my.sharepoint.com/:f:/g/personal/zwa170_sfu_ca/EoYp5Z5WiqtNuq_GOb5Yj1ABSI5lQSXG64StzXb6eTbXNg?e=N3uJ7L). Please download them, place them in `./work_dirs/`, and unzip them.

## Structure
Finally, the overall data structure should look like:
```
puzzlefusion-plusplus/
├── data
│ ├── pc_data
│ ├── verifier_data
│ ├── matching_data
└── ...
├── output
│ ├── autoencoder
│ ├── denoiser
│ ├── ...
└── ...
```
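
A quick way to confirm the layout before running anything is a small check such as the one below. It is a hedged sketch: `EXPECTED_DIRS` lists only the directories named in the tree above, and the helper name `missing_dirs` is hypothetical.

```python
from pathlib import Path
from typing import List

# Sub-directories named in the tree above, relative to the repository root.
EXPECTED_DIRS = [
    "data/pc_data",
    "data/verifier_data",
    "data/matching_data",
    "output/autoencoder",
    "output/denoiser",
]


def missing_dirs(root: str) -> List[str]:
    """Return the expected sub-directories that do not exist under `root`."""
    base = Path(root)
    return [d for d in EXPECTED_DIRS if not (base / d).is_dir()]
```

Calling `missing_dirs(".")` from the repository root returns an empty list when every listed directory is in place.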
30 changes: 25 additions & 5 deletions docs/installation.md
Original file line number Diff line number Diff line change
@@ -1,17 +1,37 @@
- ### Installation
**Step 1.** Create conda environment and activate.
### Installation

**Step 1.** Set up conda environment.

```
conda create --name puzzlefusionpp python=3.8 -y
conda activate puzzlefusionpp
```

**Step 2.** Install PyTorch.
```
conda install -y pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
```

**Step 3.** Install pytorch3d, torch-cluster, chamferdist packages.
```
# install pytorch3d
git clone https://github.com/facebookresearch/pytorch3d.git
cd pytorch3d && pip install -e .
cd ..
# install torch-cluster
pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.1+cu118.html
**Step 3.** Install PyTorch3d.
# install chamferdist
git clone https://github.com/krrish94/chamferdist.git
cd chamferdist && python setup.py install
cd ..
```

**Step 4.** Install remaining packages.

pip3 install -r requirements.txt
```
git clone https://github.com/eric-zqwang/puzzlefusion-plusplus.git
cd puzzlefusion-plusplus/
pip3 install -r requirements.txt
```
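
After the steps above, a short script can report which packages are still missing without importing them. This is a sketch under one assumption: that the import names are `torch`, `torchvision`, `pytorch3d`, `torch_cluster` and `chamferdist`. Adjust the tuple if your environment differs.

```python
import importlib.util
from typing import Iterable, List

# Import names assumed to correspond to the packages installed above.
REQUIRED_MODULES = ("torch", "torchvision", "pytorch3d", "torch_cluster", "chamferdist")


def missing_modules(names: Iterable[str] = REQUIRED_MODULES) -> List[str]:
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    print("All packages found." if not missing else f"Missing: {', '.join(missing)}")
```

`importlib.util.find_spec` only locates a module, so this check is fast but does not prove that a compiled extension actually loads against the installed CUDA version.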
Empty file removed docs/sampling.md
Empty file.
12 changes: 12 additions & 0 deletions docs/test.md
@@ -0,0 +1,12 @@
## Test
We provide our checkpoints in [data preparation](../docs/data_preparation.md).
Make sure you have downloaded all the data listed in [data preparation](../docs/data_preparation.md).
Testing only supports a batch size of one. You need to modify the checkpoint paths for both the pre-trained denoiser and the verifier in the script.
```
sh ./scripts/inference.sh
```

The denoising parameters are stored in `./output/denoiser/{experiment_name}/inference/{inference_dir}`. You can use these saved results for visualization later.
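
For scripts that post-process the results, the output location can be composed as in the sketch below. The function name `inference_output_dir` and the example values `my_experiment` and `run_01` are placeholders, not names used by the repository.

```python
from pathlib import Path


def inference_output_dir(project_root: str, experiment_name: str, inference_dir: str) -> Path:
    """Compose <root>/output/denoiser/<experiment_name>/inference/<inference_dir>."""
    return Path(project_root) / "output" / "denoiser" / experiment_name / "inference" / inference_dir


example_dir = inference_output_dir(".", "my_experiment", "run_01")
```

This mirrors the `experiment_output_path` pattern in `config/auto_aggl.yaml`, where the root is taken from the directory the run was launched in.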

[Jigsaw](https://github.com/Jiaxin-Lu/Jigsaw) uses sampling by area to generate point cloud data. The point cloud is created using their method, and the matching points are obtained from their network.

5 changes: 2 additions & 3 deletions docs/training.md
@@ -1,6 +1,6 @@
## Training

The training consists of three modules as detailed in the paper. We train the models on 4 Nvidia RTX A6000 GPUs.
The training consists of three modules as detailed in the paper. We train the vqvae and denoiser on 4 Nvidia RTX A6000 GPUs. The verifier is trained on a single RTX 4090 GPU.

**Stage 1**: VQVAE:
```
@@ -11,10 +11,9 @@ sh ./scripts/train_vqvae.sh
```
sh ./scripts/train_denoiser.sh
```
You need to modify the checkpoint path for the pre-trained VQVAE in the script.

**Stage 3**: Pairwise alignment verifier:
```
sh ./scripts/train_verifier.sh
```

We also have provided checkpoint for easier testing [here]().
6 changes: 3 additions & 3 deletions generate_pc_data.py
@@ -8,7 +8,7 @@
import numpy as np
from tqdm import tqdm

@hydra.main(config_path='config', config_name='vqvae_global_config')
@hydra.main(config_path='config/ae', config_name='global_config.yaml')
def main(cfg):
cfg.data.batch_size = 1
cfg.data.val_batch_size = 1
@@ -21,12 +21,12 @@ def save_data(loader, data_type):
for i, data_dict in tqdm(enumerate(loader), total=len(loader), desc=f"Processing {data_type} data"):
data_id = data_dict['data_id'][0].item()
part_valids = data_dict['part_valids'][0]
scale = data_dict['scale'][0]
num_parts = data_dict['num_parts'][0].item()
mesh_file_path = data_dict['mesh_file_path'][0]
graph = data_dict['graph'][0]
category = data_dict['category'][0]
part_pcs_gt = data_dict['part_pcs_gt'][0]
ref_part = data_dict['ref_part'][0]

np.savez(
os.path.join(save_path, f'{data_id:05}.npz'),
@@ -35,9 +35,9 @@ def save_data(loader, data_type):
num_parts=num_parts,
mesh_file_path=mesh_file_path,
graph=graph.cpu().numpy(),
scale=scale,
category=category,
part_pcs_gt=part_pcs_gt.cpu().numpy(),
ref_part=ref_part.cpu().numpy()
)
# print(f"Saved {data_id:05}.npz in {data_type} data.")
