Commit
1.16.19a0
yangheng95 committed Oct 12, 2022
1 parent 964d786 commit 64a917b
Showing 8 changed files with 21 additions and 21 deletions.
2 changes: 1 addition & 1 deletion pyabsa/__init__.py
@@ -7,7 +7,7 @@
# Copyright (C) 2021. All Rights Reserved.


- __version__ = '1.16.18'
+ __version__ = '1.16.19a0'

__name__ = 'pyabsa'

2 changes: 1 addition & 1 deletion pyabsa/core/apc/models/ensembler.py
@@ -80,7 +80,7 @@ def __init__(self, opt, load_dataset=True, **kwargs):
if hasattr(APCModelList, models[i].__name__):
try:

- if kwargs.pop('offline', False):
+ if kwargs.get('offline', False):
self.tokenizer = AutoTokenizer.from_pretrained(find_cwd_dir(self.opt.pretrained_bert.split('/')[-1]), do_lower_case='uncased' in self.opt.pretrained_bert)
self.bert = AutoModel.from_pretrained(find_cwd_dir(self.opt.pretrained_bert.split('/')[-1])) if not self.bert else self.bert # share the underlying bert between models
else:
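
Every change in this commit follows the same pattern: kwargs.pop(...) becomes kwargs.get(...) for the 'offline', 'auto_device', 'verbose', and 'eval_batch_size' flags. The distinction matters because dict.pop removes the key, so only the first lookup sees the caller's value and every later lookup falls back to the default, while dict.get leaves kwargs untouched. In ensembler.py, which builds several models in a loop (note models[i] above), pop would apply the 'offline' flag to the first model only. A minimal sketch of the two behaviours in plain Python (not pyabsa code):

def build_with_pop(**kwargs):
    # pop() removes 'offline' on the first read, so the second
    # "model" silently falls back to the default False.
    return [kwargs.pop('offline', False) for _ in range(2)]

def build_with_get(**kwargs):
    # get() leaves kwargs intact, so every read sees the caller's value.
    return [kwargs.get('offline', False) for _ in range(2)]

print(build_with_pop(offline=True))  # [True, False]
print(build_with_get(offline=True))  # [True, True]
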
8 changes: 4 additions & 4 deletions pyabsa/core/apc/prediction/sentiment_classifier.py
@@ -74,7 +74,7 @@ def __init__(self, model_arg=None, cal_perplexity=False, **kwargs):

with open(config_path, mode='rb') as f:
self.opt = pickle.load(f)
- self.opt.device = get_device(kwargs.pop('auto_device', True))[0]
+ self.opt.device = get_device(kwargs.get('auto_device', True))[0]

if state_dict_path or model_path:
if state_dict_path:
@@ -85,7 +85,7 @@ def __init__(self, model_arg=None, cal_perplexity=False, **kwargs):
with open(tokenizer_path, mode='rb') as f:
if hasattr(APCModelList, self.opt.model.__name__):
try:
- if kwargs.pop('offline', False):
+ if kwargs.get('offline', False):
self.tokenizer = AutoTokenizer.from_pretrained(find_cwd_dir(self.opt.pretrained_bert.split('/')[-1]), do_lower_case='uncased' in self.opt.pretrained_bert)
else:
self.tokenizer = AutoTokenizer.from_pretrained(self.opt.pretrained_bert, do_lower_case='uncased' in self.opt.pretrained_bert)
@@ -109,7 +109,7 @@ def __init__(self, model_arg=None, cal_perplexity=False, **kwargs):

self.tokenizer = tokenizer

- if kwargs.pop('verbose', False):
+ if kwargs.get('verbose', False):
print('Config used in Training:')
print_args(self.opt)

@@ -153,7 +153,7 @@ def __init__(self, model_arg=None, cal_perplexity=False, **kwargs):
# torch.backends.cudnn.benchmark = False

self.opt.initializer = self.opt.initializer
- self.opt.eval_batch_size = kwargs.pop('eval_batch_size', 128)
+ self.opt.eval_batch_size = kwargs.get('eval_batch_size', 128)

if self.cal_perplexity:
try:
10 changes: 5 additions & 5 deletions pyabsa/core/atepc/prediction/aspect_extractor.py
@@ -59,10 +59,10 @@ def __init__(self, model_arg=None, **kwargs):

with open(config_path, mode='rb') as f:
self.opt = pickle.load(f)
- self.opt.device = get_device(kwargs.pop('auto_device', True))[0]
+ self.opt.device = get_device(kwargs.get('auto_device', True))[0]
if state_dict_path:
try:
- if kwargs.pop('offline', False):
+ if kwargs.get('offline', False):
bert_base_model = AutoModel.from_pretrained(find_cwd_dir(self.opt.pretrained_bert.split('/')[-1]))
else:
bert_base_model = AutoModel.from_pretrained(self.opt.pretrained_bert)
@@ -76,7 +76,7 @@ def __init__(self, model_arg=None, **kwargs):
self.model = torch.load(model_path, map_location='cpu')
self.model.opt = self.opt
try:
- if kwargs.pop('offline', False):
+ if kwargs.get('offline', False):
self.tokenizer = AutoTokenizer.from_pretrained(find_cwd_dir(self.opt.pretrained_bert.split('/')[-1]))
else:
self.tokenizer = AutoTokenizer.from_pretrained(self.opt.pretrained_bert, do_lower_case='uncased' in self.opt.pretrained_bert)
@@ -102,7 +102,7 @@ def __init__(self, model_arg=None, **kwargs):
# np.random.seed(self.opt.seed)
# torch.manual_seed(self.opt.seed)

- if kwargs.pop('verbose', False):
+ if kwargs.get('verbose', False):
print('Config used in Training:')
print_args(self.opt)

@@ -112,7 +112,7 @@ def __init__(self, model_arg=None, **kwargs):
self.opt.gradient_accumulation_steps))

self.eval_dataloader = None
- self.opt.eval_batch_size = kwargs.pop('eval_batch_size', 128)
+ self.opt.eval_batch_size = kwargs.get('eval_batch_size', 128)

self.to(self.opt.device)

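
In aspect_extractor.py the same 'offline' flag is consulted twice inside one __init__: once when loading the base model and once when loading the tokenizer. With pop, the first check consumed the flag and the tokenizer branch always saw the default; get keeps both checks consistent. A small reproduction of that pattern (the Loader class below is illustrative, not the pyabsa class):

class Loader:
    def __init__(self, **kwargs):
        # With kwargs.pop the first line would remove 'offline' and the
        # second lookup would fall back to False; get() preserves it.
        self.model_offline = kwargs.get('offline', False)      # model loading
        self.tokenizer_offline = kwargs.get('offline', False)  # tokenizer loading

loader = Loader(offline=True)
print(loader.model_offline, loader.tokenizer_offline)  # True True
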
@@ -14,7 +14,7 @@

class Tokenizer4Pretraining:
def __init__(self, max_seq_len, opt, **kwargs):
- if kwargs.pop('offline', False):
+ if kwargs.get('offline', False):
self.tokenizer = AutoTokenizer.from_pretrained(find_cwd_dir(opt.pretrained_bert.split('/')[-1]),
do_lower_case='uncased' in opt.pretrained_bert)
else:
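
Tokenizer4Pretraining also reads 'offline' from **kwargs, and its signature suggests those kwargs are forwarded from a caller. If that caller had already popped the flag, the forwarded Tokenizer4Pretraining would never see it; using get upstream avoids this. A sketch of the forwarding issue (the Tok and outer_* names are illustrative, not pyabsa code):

class Tok:
    def __init__(self, **kwargs):
        self.offline = kwargs.get('offline', False)

def outer_with_pop(**kwargs):
    offline = kwargs.pop('offline', False)   # flag removed from kwargs here...
    return offline, Tok(**kwargs).offline    # ...so Tok falls back to False

def outer_with_get(**kwargs):
    offline = kwargs.get('offline', False)   # flag stays in kwargs
    return offline, Tok(**kwargs).offline

print(outer_with_pop(offline=True))  # (True, False)
print(outer_with_get(offline=True))  # (True, True)
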
8 changes: 4 additions & 4 deletions pyabsa/core/tad/prediction/tad_classifier.py
@@ -125,12 +125,12 @@ def __init__(self, model_arg=None, cal_perplexity=False, **kwargs):

with open(config_path, mode='rb') as f:
self.opt = pickle.load(f)
- self.opt.device = get_device(kwargs.pop('auto_device', True))[0]
+ self.opt.device = get_device(kwargs.get('auto_device', True))[0]

if state_dict_path or model_path:
if hasattr(BERTTADModelList, self.opt.model.__name__):
if state_dict_path:
- if kwargs.pop('offline', False):
+ if kwargs.get('offline', False):
self.bert = AutoModel.from_pretrained(
find_cwd_dir(self.opt.pretrained_bert.split('/')[-1]))
else:
@@ -172,7 +172,7 @@ def __init__(self, model_arg=None, cal_perplexity=False, **kwargs):

self.tokenizer = tokenizer

- if kwargs.pop('verbose', False):
+ if kwargs.get('verbose', False):
print('Config used in Training:')
print_args(self.opt)

@@ -184,7 +184,7 @@ def __init__(self, model_arg=None, cal_perplexity=False, **kwargs):
raise KeyError('The checkpoint you are loading is not from classifier model.')

self.infer_dataloader = None
- self.opt.eval_batch_size = kwargs.pop('eval_batch_size', 128)
+ self.opt.eval_batch_size = kwargs.get('eval_batch_size', 128)

# if self.opt.seed is not None:
# random.seed(self.opt.seed)
@@ -14,7 +14,7 @@

class Tokenizer4Pretraining:
def __init__(self, max_seq_len, opt, **kwargs):
- if kwargs.pop('offline', False):
+ if kwargs.get('offline', False):
self.tokenizer = AutoTokenizer.from_pretrained(find_cwd_dir(opt.pretrained_bert.split('/')[-1]),
do_lower_case='uncased' in opt.pretrained_bert)
else:
8 changes: 4 additions & 4 deletions pyabsa/core/tc/prediction/text_classifier.py
@@ -74,12 +74,12 @@ def __init__(self, model_arg=None, cal_perplexity=False, **kwargs):

with open(config_path, mode='rb') as f:
self.opt = pickle.load(f)
- self.opt.device = get_device(kwargs.pop('auto_device', True))[0]
+ self.opt.device = get_device(kwargs.get('auto_device', True))[0]

if state_dict_path or model_path:
if hasattr(BERTTCModelList, self.opt.model.__name__):
if state_dict_path:
- if kwargs.pop('offline', False):
+ if kwargs.get('offline', False):
self.bert = AutoModel.from_pretrained(
find_cwd_dir(self.opt.pretrained_bert.split('/')[-1]))
else:
@@ -121,7 +121,7 @@ def __init__(self, model_arg=None, cal_perplexity=False, **kwargs):

self.tokenizer = tokenizer

- if kwargs.pop('verbose', False):
+ if kwargs.get('verbose', False):
print('Config used in Training:')
print_args(self.opt)

@@ -139,7 +139,7 @@ def __init__(self, model_arg=None, cal_perplexity=False, **kwargs):
self.dataset = GloVeTCDataset(tokenizer=self.tokenizer, opt=self.opt)

self.infer_dataloader = None
- self.opt.eval_batch_size = kwargs.pop('eval_batch_size', 128)
+ self.opt.eval_batch_size = kwargs.get('eval_batch_size', 128)

# if self.opt.seed is not None:
# random.seed(self.opt.seed)
