Skip to content

Commit

Permalink
Merge branch 'yjyjy131_dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
yjyjy131 committed Jun 26, 2022
2 parents 296bba4 + 0f44ef7 commit b870cb7
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 19 deletions.
8 changes: 7 additions & 1 deletion SPFCN/model/detector.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import dill

from .network import SlotNetwork

Expand All @@ -10,7 +11,12 @@ def __init__(self, device_id: int, **kwargs):
print(self.config)
self.network = SlotNetwork(self.config['dim_encoder'], device_id)
self.network.merge()
self.network.load_state_dict(torch.load(self.config['parameter_path'], map_location=self.device))
try:
self.network.load_state_dict(torch.load(self.config['parameter_path'], map_location=self.device))
except RuntimeError:
net_path = self.config['parameter_path'].replace('.pkl', '.pt')
network = torch.load(net_path, map_location=self.device)
self.network= dill.loads(network)
self.network.eval()

def update_config(self, **kwargs):
Expand Down
10 changes: 5 additions & 5 deletions SPFCN/test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@


@torch.no_grad()
def auto_test(dataset,
network,
device_id: int = 0,
load_path: str = None):
def auto_test(dataset,
network,
device_id: int = 0,
load_path: str = None):
device = torch.device('cpu' if device_id < 0 else 'cuda:%d' % device_id)

try:
Expand All @@ -19,7 +19,7 @@ def auto_test(dataset,
net_path = load_path + '.pt'
assert os.path.exists(net_path)
network = torch.load(net_path, map_location=device)
network=dill.loads(network)
network= dill.loads(network)
network.eval()

auto_tester = Tester(dataset, network, device)
Expand Down
17 changes: 4 additions & 13 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,14 @@
# slot_network_training(data_num=6535, batch_size=10, epoch=10, input_res=224, device_id=0, num_workers=0)

# load trained model
encoder = [32, 44, 64, 92, 128]
net = SlotNetwork(encoder, device_id=0)
model_path = './parameters/merge_bn_epoch10_loss4.pt'
model = torch.load(model_path)
model = dill.loads(model)
model_path = './parameters/merge_bn_epoch10_loss4.pkl'


# auto test
model_path = './SPFCN/'
slot_network_testing(params_path='./parameters/merge_bn_epoch10_loss4', data_num=1500, batch_size=50, input_res=224, device_id=0, num_workers=0)
# slot_network_testing(params_path='./parameters/merge_bn_epoch10_loss4', data_num=1500, batch_size=50, input_res=224, device_id=0, num_workers=0)

# Load detector
config = {
'dim_encoder': encoder,
'parameter_path': model_path,
}
detector = SlotDetector(0, config)
detector = SlotDetector(device_id=0, dim_encoder=[32, 44, 64, 92, 128], parameter_path=model_path)

# Visualize the merge image with result
current_frame = cv2.imread("demo.jpg")
Expand All @@ -45,7 +37,6 @@
cv2.line(current_frame, pt0, pt3, (0, 0, 255), thickness=2)
cv2.line(current_frame, pt1, pt2, (0, 0, 255), thickness=2)
cv2.line(current_frame, pt2, pt3, (0, 0, 255), thickness=2)
cv2.putText(current_frame, "%.2f fps" % infer_fps, (30, 30), cv2.FONT_HERSHEY_COMPLEX, 1.0, (0, 0, 255))
cv2.imwrite("result.jpg", current_frame)


Expand Down

0 comments on commit b870cb7

Please sign in to comment.