-
Notifications
You must be signed in to change notification settings - Fork 81
/
Copy pathinteraction.py
99 lines (79 loc) · 3.35 KB
/
interaction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
Contains all the interaction types related to the GUI.
Not related to automatic evaluation on the DAVIS dataset.
Inherit from the Interaction class to create new interaction types.
Undo is supported, though sometimes only partially.
"""
from typing import Tuple
import torch
import torch.nn.functional as F
from gui.click_controller import ClickController
def aggregate_wbg(prob: torch.Tensor, keep_bg: bool = False, hard: bool = False) -> torch.Tensor:
    """Combine per-object probabilities with an implicit background channel.

    A background probability is formed as the product of ``1 - prob`` over
    dim 0 and prepended to ``prob``. Every channel is then clamped to
    ``[1e-7, 1 - 1e-7]``, converted to logits and renormalised with a softmax
    over dim 0.

    Args:
        prob: Per-object probabilities with the object axis at dim 0. Any
            number of trailing dimensions is accepted. Values are presumably
            in ``[0, 1]``; this is not checked here.
        keep_bg: If True, the result keeps the background channel at index 0;
            otherwise only the object channels are returned.
        hard: If True, logits are multiplied by 1000 before the softmax, which
            pushes the result towards a one-hot assignment.

    Returns:
        The softmax over dim 0, with ``prob.shape[0] + 1`` channels when
        ``keep_bg`` is True and ``prob.shape[0]`` channels otherwise.
    """
    background = torch.prod(1 - prob, dim=0, keepdim=True)
    # Clamp away from exactly 0 and 1 so the logit below stays finite.
    new_prob = torch.cat([background, prob], 0).clamp(1e-7, 1 - 1e-7)
    logits = torch.log(new_prob / (1 - new_prob))

    if hard:
        # Equivalent to a very low softmax temperature.
        logits = logits * 1000

    merged = F.softmax(logits, dim=0)
    return merged if keep_bg else merged[1:]
class Interaction:
    """
    Base class for GUI interactions; subclass it and override predict().
    Keeps the inputs it was built with and two result slots
    (out_prob, out_mask) that start out as None.
    """

    def __init__(self, image: torch.Tensor, prev_mask: torch.Tensor, true_size: Tuple[int, int],
                 controller: ClickController):
        """
        image      - image tensor this interaction works on
        prev_mask  - mask tensor from before this interaction
        true_size  - (h, w) pair, stored as self.h and self.w
        controller - click controller kept for subclasses to use
        """
        self.image, self.prev_mask = image, prev_mask
        self.controller = controller

        height, width = true_size
        self.h = height
        self.w = width

        # Results; nothing has been computed yet.
        self.out_prob = None
        self.out_mask = None

    def predict(self):
        """Placeholder in the base class: does nothing and returns None."""
        return None
class ClickInteraction(Interaction):
    """
    Click-based interaction on a single target object.
    Clicks are recorded per polarity and forwarded to the controller;
    predict() merges the controller output back into the full mask.
    """

    def __init__(self, image: torch.Tensor, prev_mask: torch.Tensor, true_size: Tuple[int, int],
                 controller: ClickController, tar_obj: int):
        """
        prev_mask in a prob. form
        tar_obj - index into prev_mask of the object being edited
        """
        super().__init__(image, prev_mask, true_size, controller)
        self.tar_obj = tar_obj

        # Click history, kept separately for positive and negative clicks
        self.pos_clicks, self.neg_clicks = [], []

        self.first_click = True
        self.out_prob = self.prev_mask.clone()

    def push_point(self, x: int, y: int, is_neg: bool) -> None:
        """
        Record a click at (x, y) and run the controller on it.

        is_neg - Negative interaction or not

        The target object's slice of prev_mask is passed to the controller
        on the first click only; later clicks pass prev_mask=None.
        The controller's return value is stored in self.obj_mask.
        """
        # Clicks
        click_log = self.neg_clicks if is_neg else self.pos_clicks
        click_log.append((x, y))

        # Only the very first click is seeded with the previous mask
        seed_mask = None
        if self.first_click:
            seed_mask = self.prev_mask[self.tar_obj].unsqueeze(0).unsqueeze(0)

        # Do the prediction
        self.obj_mask = self.controller.interact(
            self.image.unsqueeze(0),
            x,
            y,
            not is_neg,
            prev_mask=seed_mask,
        )
        self.first_click = False

    def predict(self) -> torch.Tensor:
        """
        Write the latest obj_mask into a copy of prev_mask and re-aggregate.

        Index 0 of the mask is dropped before aggregation and a fresh
        background channel is computed by aggregate_wbg (presumably index 0
        is the background - confirm with the caller).
        Requires push_point() to have been called at least once.
        """
        # a small hack to allow the interacting object to overwrite existing masks
        # without remembering all the object probabilities
        self.out_prob = torch.clamp(self.prev_mask.clone(), max=0.9)
        self.out_prob[self.tar_obj] = self.obj_mask

        aggregated = aggregate_wbg(self.out_prob[1:], keep_bg=True, hard=True)
        self.out_prob = aggregated
        return aggregated