diff --git a/disttest.py b/disttest.py index e375b8c..b2b5359 100644 --- a/disttest.py +++ b/disttest.py @@ -46,8 +46,6 @@ device_id = torch.cuda.current_device() resume_state = torch.load(opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) -logger.info('Resuming training from epoch: {}.'.format( - resume_state['epoch'])) def corresponding_load(pre_name, state_dict): sub_statedict = {} @@ -56,12 +54,14 @@ def corresponding_load(pre_name, state_dict): sub_statedict[k.replace(pre_name, "")] = v return sub_statedict -resume_state['state_dict'] = corresponding_load('module.', resume_state['state_dict']) +# resume_state['state_dict'] = corresponding_load('module.', resume_state['state_dict']) model.load_state_dict(resume_state['state_dict']) + model = model.cuda() + # testing max_steps = len(test_loader)