-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsolve_vg.py
31 lines (23 loc) · 1.04 KB
/
solve_vg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from LPProblem import VGLPProblem
from clevr_block_gen.constraints import get_vg_constraint_map
import torch
import json
from tqdm import tqdm
from multiprocessing import Pool
# Attribute schema passed as ``attr_map`` to every VGLPProblem built in
# get_problem. It is loaded at import time (outside the ``__main__`` guard)
# so that Pool worker processes executing get_problem can also see it.
# The path is resolved against the current working directory.
# Read explicitly as UTF-8 (the standard JSON encoding) rather than relying
# on the platform's default locale encoding.
with open('./results/vg/schema.json', 'r', encoding='utf-8') as f:
    schema = json.load(f)
def get_problem(scene):
    """Solve the VG LP problem for a single scene and return the result.

    Every call builds a fresh ``VGLPProblem`` that uses the module-level
    ``schema`` as its attribute map, with the ``opposite``, ``person`` and
    ``loop`` options enabled. It then returns whatever
    ``solve_for_scene(scene)`` produces.

    This function is the work item mapped over scenes by a ``Pool``, so its
    return value must be picklable. The ``__main__`` block also JSON-dumps
    the collected results.
    """
    # NOTE(review): a new solver per scene presumably avoids carrying state
    # between scenes; confirm before caching one solver per process.
    solver_flags = {'opposite': True, 'person': True, 'loop': True}
    solver = VGLPProblem(attr_map=schema, **solver_flags)
    result = solver.solve_for_scene(scene)
    return result
if __name__ == '__main__':
    # Number of worker processes solving scenes in parallel.
    num_workers = 10

    # NOTE(review): torch.load is pickle-based (unless weights_only=True), so
    # only load this file if it was produced locally by this pipeline.
    with open('./results/vg/probablistic_scenes.pytorch', 'rb') as f:
        scenes = torch.load(f)['scenes']

    # imap yields results in input order, so predicted_scenes[i] corresponds
    # to scenes[i]. tqdm needs total= because imap returns a bare iterator.
    # list() drains every result before the Pool context exits (and
    # terminates the workers).
    with Pool(num_workers) as p:
        predicted_scenes = list(
            tqdm(p.imap(get_problem, scenes), total=len(scenes))
        )

    # Write UTF-8 explicitly instead of relying on the platform's locale.
    with open('./results/vg/scene_fixed.json', 'w', encoding='utf-8') as f:
        json.dump({'scenes': predicted_scenes}, f)