Commit: Formatter and linter (#119)
flaviussn committed Jan 20, 2020
1 parent fecd7b7 · commit ec73d6b
Showing 11 changed files with 1,541 additions and 726 deletions.
simpletransformers/__init__.py: 1 addition & 1 deletion
@@ -1 +1 @@
-name = "simpletransformers"
+name = "simpletransformers"
simpletransformers/classification/__init__.py: 3 additions & 1 deletion
@@ -1,2 +1,4 @@
 from simpletransformers.classification.classification_model import ClassificationModel
-from simpletransformers.classification.multi_label_classification_model import MultiLabelClassificationModel
+from simpletransformers.classification.multi_label_classification_model import (
+    MultiLabelClassificationModel,
+)
simpletransformers/classification/classification_model.py: 502 additions & 217 deletions

Large diffs are not rendered by default.

simpletransformers/classification/classification_utils.py: 109 additions & 27 deletions
@@ -72,9 +72,23 @@ def convert_example_to_feature(
     cls_token_segment_id=1,
     pad_token_segment_id=0,
     mask_padding_with_zero=True,
-    sep_token_extra=False
+    sep_token_extra=False,
 ):
-    example, max_seq_length, tokenizer, output_mode, cls_token_at_end, cls_token, sep_token, cls_token_segment_id, pad_on_left, pad_token_segment_id, sep_token_extra, multi_label, stride = example_row
+    (
+        example,
+        max_seq_length,
+        tokenizer,
+        output_mode,
+        cls_token_at_end,
+        cls_token,
+        sep_token,
+        cls_token_segment_id,
+        pad_on_left,
+        pad_token_segment_id,
+        sep_token_extra,
+        multi_label,
+        stride,
+    ) = example_row
 
     tokens_a = tokenizer.tokenize(example.text_a)
 
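The refactor above keeps the arguments packed in a single `example_row` tuple. That packing exists because `multiprocessing.Pool.imap` calls the mapped function with exactly one item, so every per-example parameter has to travel inside that item. A minimal, self-contained sketch of the same pattern (the `worker` function and its fields are illustrative, not from the repository):

    from multiprocessing import Pool

    def worker(packed_args):
        # Pool.imap passes one item per call, so all arguments are
        # bundled into a tuple and unpacked inside the worker.
        text, max_len, upper = packed_args
        result = text.upper() if upper else text
        return result[:max_len]

    if __name__ == "__main__":
        rows = [("hello world", 5, True), ("foo bar", 3, False)]
        with Pool(2) as p:
            print(list(p.imap(worker, rows)))  # ['HELLO', 'foo']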
@@ -90,7 +104,7 @@ def convert_example_to_feature(
         # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
         special_tokens_count = 3 if sep_token_extra else 2
         if len(tokens_a) > max_seq_length - special_tokens_count:
-            tokens_a = tokens_a[:(max_seq_length - special_tokens_count)]
+            tokens_a = tokens_a[: (max_seq_length - special_tokens_count)]
 
     # The convention in BERT is:
     # (a) For sequence pairs:
@@ -134,11 +148,15 @@ def convert_example_to_feature(
     padding_length = max_seq_length - len(input_ids)
     if pad_on_left:
         input_ids = ([pad_token] * padding_length) + input_ids
-        input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
+        input_mask = (
+            [0 if mask_padding_with_zero else 1] * padding_length
+        ) + input_mask
         segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
     else:
         input_ids = input_ids + ([pad_token] * padding_length)
-        input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
+        input_mask = input_mask + (
+            [0 if mask_padding_with_zero else 1] * padding_length
+        )
         segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)
 
     assert len(input_ids) == max_seq_length
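The hunk above is pure reformatting of the padding logic: sequences are padded out to `max_seq_length` on the left (the XLNet convention) or on the right (the BERT convention), and the attention mask marks real tokens with 1 and padding with 0 while `mask_padding_with_zero` is true. The same logic in isolation, as a rough sketch with illustrative values:

    def pad_to_length(input_ids, input_mask, segment_ids, max_seq_length,
                      pad_token=0, pad_token_segment_id=0,
                      pad_on_left=False, mask_padding_with_zero=True):
        padding_length = max_seq_length - len(input_ids)
        pad_mask = [0 if mask_padding_with_zero else 1] * padding_length
        if pad_on_left:
            return ([pad_token] * padding_length + input_ids,
                    pad_mask + input_mask,
                    [pad_token_segment_id] * padding_length + segment_ids)
        return (input_ids + [pad_token] * padding_length,
                input_mask + pad_mask,
                segment_ids + [pad_token_segment_id] * padding_length)

    # pad_to_length([101, 2023, 102], [1, 1, 1], [0, 0, 0], 5)
    # -> ([101, 2023, 102, 0, 0], [1, 1, 1, 0, 0], [0, 0, 0, 0, 0])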
@@ -152,14 +170,14 @@ def convert_example_to_feature(
     # else:
     #     raise KeyError(output_mode)
 
-    if output_mode == 'regression':
-        label_id = float(example.label)
+    # if output_mode == "regression":
+    #     label_id = float(example.label)
 
     return InputFeatures(
         input_ids=input_ids,
         input_mask=input_mask,
         segment_ids=segment_ids,
-        label_id=example.label
+        label_id=example.label,
     )
 
 
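`InputFeatures` is the container these functions return. Its definition is not part of this hunk; the sketch below assumes the conventional four-field layout used by the transformers GLUE utilities this file derives from:

    class InputFeatures(object):
        """A single set of features for one example."""

        def __init__(self, input_ids, input_mask, segment_ids, label_id):
            self.input_ids = input_ids      # token ids, padded to max_seq_length
            self.input_mask = input_mask    # 1 for real tokens, 0 for padding
            self.segment_ids = segment_ids  # sentence A/B segment ids
            self.label_id = label_id        # label (a float for regression tasks)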
@@ -173,7 +191,21 @@ def convert_example_to_feature_sliding_window(
     mask_padding_with_zero=True,
     sep_token_extra=False,
 ):
-    example, max_seq_length, tokenizer, output_mode, cls_token_at_end, cls_token, sep_token, cls_token_segment_id, pad_on_left, pad_token_segment_id, sep_token_extra, multi_label, stride = example_row
+    (
+        example,
+        max_seq_length,
+        tokenizer,
+        output_mode,
+        cls_token_at_end,
+        cls_token,
+        sep_token,
+        cls_token_segment_id,
+        pad_on_left,
+        pad_token_segment_id,
+        sep_token_extra,
+        multi_label,
+        stride,
+    ) = example_row
 
     if stride < 1:
         stride = int(max_seq_length * stride)
@@ -183,14 +215,17 @@ def convert_example_to_feature_sliding_window(
 
     tokens_a = tokenizer.tokenize(example.text_a)
 
-    special_tokens_count = 3 if sep_token_extra else 2
     if len(tokens_a) > bucket_size:
-        token_sets = [tokens_a[i:i + bucket_size] for i in range(0, len(tokens_a), stride)]
+        token_sets = [
+            tokens_a[i : i + bucket_size] for i in range(0, len(tokens_a), stride)
+        ]
     else:
         token_sets.append(tokens_a)
 
     if example.text_b:
-        raise ValueError("Sequence pair tasks not implemented for sliding window tokenization.")
+        raise ValueError(
+            "Sequence pair tasks not implemented for sliding window tokenization."
+        )
 
     # The convention in BERT is:
     # (a) For sequence pairs:
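The windowing above turns one long example into several overlapping buckets: each window is `bucket_size` tokens wide and starts `stride` tokens after the previous one, with a `stride` below 1 treated as a fraction of `max_seq_length` (see the earlier hunk). A simplified, self-contained illustration (the helper name and the `bucket_size` computation here are assumptions, not the repository's exact code):

    def make_windows(tokens, max_seq_length, stride, special_tokens_count=2):
        if stride < 1:  # fractional stride is relative to max_seq_length
            stride = int(max_seq_length * stride)
        bucket_size = max_seq_length - special_tokens_count
        if len(tokens) > bucket_size:
            return [tokens[i : i + bucket_size]
                    for i in range(0, len(tokens), stride)]
        return [tokens]

    # make_windows(list("abcdefgh"), max_seq_length=6, stride=0.5)
    # -> [['a', 'b', 'c', 'd'], ['d', 'e', 'f', 'g'], ['g', 'h']]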
@@ -233,11 +268,15 @@ def convert_example_to_feature_sliding_window(
         padding_length = max_seq_length - len(input_ids)
         if pad_on_left:
             input_ids = ([pad_token] * padding_length) + input_ids
-            input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
+            input_mask = (
+                [0 if mask_padding_with_zero else 1] * padding_length
+            ) + input_mask
             segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
         else:
             input_ids = input_ids + ([pad_token] * padding_length)
-            input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
+            input_mask = input_mask + (
+                [0 if mask_padding_with_zero else 1] * padding_length
+            )
             segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)
 
         assert len(input_ids) == max_seq_length
@@ -256,7 +295,7 @@ def convert_example_to_feature_sliding_window(
                 input_ids=input_ids,
                 input_mask=input_mask,
                 segment_ids=segment_ids,
-                label_id=example.label
+                label_id=example.label,
             )
         )
 
@@ -285,7 +324,7 @@ def convert_examples_to_features(
     use_multiprocessing=True,
     sliding_window=False,
     flatten=False,
-    stride=None
+    stride=None,
 ):
     """ Loads a data file into a list of `InputBatch`s
         `cls_token_at_end` define the location of the CLS token:
@@ -294,28 +333,71 @@ def convert_examples_to_features(
         `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
     """
 
-    examples = [(example, max_seq_length, tokenizer, output_mode, cls_token_at_end, cls_token, sep_token, cls_token_segment_id, pad_on_left, pad_token_segment_id, sep_token_extra, multi_label, stride) for example in examples]
+    examples = [
+        (
+            example,
+            max_seq_length,
+            tokenizer,
+            output_mode,
+            cls_token_at_end,
+            cls_token,
+            sep_token,
+            cls_token_segment_id,
+            pad_on_left,
+            pad_token_segment_id,
+            sep_token_extra,
+            multi_label,
+            stride,
+        )
+        for example in examples
+    ]
 
     if use_multiprocessing:
         if sliding_window:
-            print('sliding_window enabled')
+            print("sliding_window enabled")
             with Pool(process_count) as p:
-                features = list(tqdm(p.imap(convert_example_to_feature_sliding_window, examples, chunksize=500), total=len(examples), disable=silent))
+                features = list(
+                    tqdm(
+                        p.imap(
+                            convert_example_to_feature_sliding_window,
+                            examples,
+                            chunksize=500,
+                        ),
+                        total=len(examples),
+                        disable=silent,
+                    )
+                )
             if flatten:
-                features = [feature for feature_set in features for feature in feature_set]
-                print(f'{len(features)} features created from {len(examples)} samples.')
+                features = [
+                    feature for feature_set in features for feature in feature_set
+                ]
+                print(f"{len(features)} features created from {len(examples)} samples.")
         else:
            with Pool(process_count) as p:
-                features = list(tqdm(p.imap(convert_example_to_feature, examples, chunksize=500), total=len(examples), disable=silent))
+                features = list(
+                    tqdm(
+                        p.imap(convert_example_to_feature, examples, chunksize=500),
+                        total=len(examples),
+                        disable=silent,
+                    )
+                )
     else:
         if sliding_window:
-            print('sliding_window enabled')
-            features = [convert_example_to_feature_sliding_window(example) for example in tqdm(examples, disable=silent)]
+            print("sliding_window enabled")
+            features = [
+                convert_example_to_feature_sliding_window(example)
+                for example in tqdm(examples, disable=silent)
+            ]
             if flatten:
-                features = [feature for feature_set in features for feature in feature_set]
-                print(f'{len(features)} features created from {len(examples)} samples.')
+                features = [
+                    feature for feature_set in features for feature in feature_set
+                ]
+                print(f"{len(features)} features created from {len(examples)} samples.")
         else:
-            features = [convert_example_to_feature(example) for example in tqdm(examples, disable=silent)]
+            features = [
+                convert_example_to_feature(example)
+                for example in tqdm(examples, disable=silent)
+            ]
 
     return features
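The reformatted block keeps the original fan-out strategy: `Pool.imap` streams the packed example tuples to worker processes in chunks of 500, `tqdm` wraps the resulting iterator for progress display, and the sliding-window path flattens the per-example feature lists afterwards. The same pattern in isolation, as a sketch with toy data:

    from multiprocessing import Pool
    from tqdm import tqdm

    def square(row):
        (x,) = row  # one-element tuple, mirroring the packed example rows
        return x * x

    if __name__ == "__main__":
        rows = [(i,) for i in range(10_000)]
        with Pool(4) as p:
            # chunksize batches items per worker dispatch, which matters
            # when each individual task is cheap.
            results = list(tqdm(p.imap(square, rows, chunksize=500),
                                total=len(rows)))
        print(len(results))  # 10000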

(Diff truncated; the remaining changed files are not rendered.)