Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
  • Loading branch information
passing2961 committed Feb 13, 2023
1 parent 949b2b3 commit 095d12f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions modules/empathy_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ def __init__(self, opt, batch_size=1, cuda_device=0):
self.model_EX = BiEncoderAttentionWithRationaleClassification()
self.model_ER = BiEncoderAttentionWithRationaleClassification()

IP_weights = torch.load(os.path.join('/path/to/save/dir', 'finetuned_IP.pth'))
IP_weights = torch.load(os.path.join(opt['epitome_save_dir'], 'finetuned_IP.pth'))
self.model_IP.load_state_dict(IP_weights)

EX_weights = torch.load(os.path.join('/path/to/save/dir', 'finetuned_EX.pth'))
EX_weights = torch.load(os.path.join(opt['epitome_save_dir'], 'finetuned_EX.pth'))
self.model_EX.load_state_dict(EX_weights)

ER_weights = torch.load(os.path.join('/path/to/save/dir', 'finetuned_ER.pth'))
ER_weights = torch.load(os.path.join(opt['epitome_save_dir'], 'finetuned_ER.pth'))
self.model_ER.load_state_dict(ER_weights)

#self.use_cuda = not opt['no_cuda'] and torch.cuda.is_available()
Expand Down

0 comments on commit 095d12f

Please sign in to comment.