run.py
from functools import partial
import sys, os
import argparse
import numpy as np
import math
import logging
import torch
from torch.utils.data import DataLoader
from dataset import dataset_dict
from model import build_model
from utils.training_utils import load_config
from utils.checkpoints import load_best_checkpoints
from utils.eval_metric import compute_evaluation_metrics
from utils.generation import generate_meshes, generate_pointclouds, define_userhandle_folder_name
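
# Example invocation. The config path below is only a placeholder; the expected
# file format is whatever utils.training_utils.load_config accepts, so this is
# a sketch rather than a verbatim command from the repository:
#
#   python run.py path/to/experiment_config.yaml --num_workers 4 --num_threads 4
#
# --num_workers defaults to 0 (data are loaded in the main process) and
# --num_threads defaults to 4.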

def main(argv):
    parser = argparse.ArgumentParser(
        description="Run a trained deformation network on the test set"
    )
    parser.add_argument(
        "config_file",
        help="Path to the file that contains the experiment configuration"
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="The number of processes spawned by the batch provider"
    )
    parser.add_argument(
        "--num_threads",
        type=int,
        default=4,
        help="The number of threads"
    )
    args = parser.parse_args(argv)

    # Disable trimesh's logger
    logging.getLogger("trimesh").setLevel(logging.ERROR)

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    print("Running code on", device)

    # Parse the config file
    config = load_config(args.config_file)
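
    # For reference, a minimal sketch of the entries this script reads from the
    # loaded config. The keys mirror the accesses below; the values are only
    # illustrative placeholders, not defaults taken from the repository:
    #
    #   {
    #       "experiment": {"out_dir": "out/", "name": "my_experiment"},
    #       "data": {"type": "<key into dataset_dict>"},
    #       "test": {
    #           "iden_split": "...", "motion_split": "...",
    #           "load_mesh": True, "num_sampled_pairs": 100,
    #           "batch_size": 1,                     # optional, defaults to 1
    #           "weight_file": "<checkpoint path>",  # optional, passed to build_model
    #           "generate_mesh": True, "mesh_folder": "meshes",
    #           "mesh_format": "...",
    #           "generate_pointcloud": True, "pointcloud_folder": "pointclouds",
    #           "pointcloud_format": "...",
    #       },
    #   }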

    # Check if the output directory exists and create it if it doesn't
    output_directory = config["experiment"]["out_dir"]
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    # Create an experiment directory using the experiment name
    experiment_name = config["experiment"]["name"]
    experiment_directory = os.path.join(
        output_directory,
        experiment_name
    )
    if not os.path.exists(experiment_directory):
        os.makedirs(experiment_directory)

    # Parse the dataset
    dataset_type = config["data"]["type"]
    Dataset = dataset_dict[dataset_type]
    test_dataset = Dataset(
        config,
        iden_split=config["test"]["iden_split"],
        motion_split=config["test"]["motion_split"],
        load_mesh=config["test"]["load_mesh"],
        num_sampled_pairs=config["test"]["num_sampled_pairs"]
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=config["test"].get("batch_size", 1),
        num_workers=args.num_workers,
        collate_fn=test_dataset.collate_fn,
        shuffle=False,
    )
    print("Loaded {} test deformation pairs".format(len(test_dataset)))

    # Build the network architecture and restore the weights used for testing
    weight_file = config["test"].get("weight_file")
    model, _, _, test_on_batch = build_model(
        config, weight_file, device=device
    )

    # Define the output folder names
    if config["test"]["generate_mesh"]:
        generation_mesh_directory = os.path.join(
            output_directory,
            experiment_name,
            define_userhandle_folder_name(config),
            config["test"]["mesh_folder"],
        )
        if not os.path.exists(generation_mesh_directory):
            os.makedirs(generation_mesh_directory)
        print("Saving generated meshes to {}".format(generation_mesh_directory))
    if config["test"]["generate_pointcloud"]:
        generation_pointcloud_directory = os.path.join(
            output_directory,
            experiment_name,
            define_userhandle_folder_name(config),
            config["test"]["pointcloud_folder"],
        )
        if not os.path.exists(generation_pointcloud_directory):
            os.makedirs(generation_pointcloud_directory)
        print("Saving generated pointclouds to {}".format(generation_pointcloud_directory))

    # Do the inference
    print("====> Interactive Editing / Run-batch-processing ====>")
    model.eval()
    for b, sample in enumerate(test_loader):
        # Move everything to the device
        for k, v in sample.items():
            sample[k] = v.to(device)
        # Run the forward pass (no loss computation at test time)
        _, out_dict = test_on_batch(model, sample, config, compute_loss=False)
        # Get the deformation pair info of the b-th test sample
        sample_idx = out_dict["index"].item()
        meta_data = test_dataset.get_metadata(sample_idx)
        # Generate source / canonical / target meshes and/or point clouds
        if config["test"]["generate_mesh"]:
            generate_meshes(
                generation_mesh_directory, out_dict, meta_data,
                config["test"]["mesh_format"], vert_pred_color=False
            )
        if config["test"]["generate_pointcloud"]:
            generate_pointclouds(
                generation_pointcloud_directory, out_dict, meta_data,
                config["test"]["pointcloud_format"]
            )
    print("====> Interactive Editing / Run-batch-processing ====>")


if __name__ == "__main__":
    main(sys.argv[1:])