Skip to content

Commit

Permalink
Update test_fastdvdnet.py
Browse files Browse the repository at this point in the history
state_temp_dict = torch.load(args['model_file'])
to
state_temp_dict = torch.load(args['model_file'], map_location=device)
  • Loading branch information
m-tassano authored Dec 14, 2020
1 parent 905d28f commit a842411
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion test_fastdvdnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_fastdvdnet(**args):
model_temp = FastDVDnet(num_input_frames=NUM_IN_FR_EXT)
# Load saved weights
state_temp_dict = torch.load(args['model_file'])
state_temp_dict = torch.load(args['model_file'], map_location=device)
if args['cuda']:
device_ids = [0]
model_temp = nn.DataParallel(model_temp, device_ids=device_ids).cuda()
Expand Down

0 comments on commit a842411

Please sign in to comment.