
Commit e1e83cf

Use black to reformat the project (#474)
* black
* some format updates
* fix mypy ignore
* black format
* ignore test output files
* ignore source line

1 parent: 7e94fee · commit: e1e83cf

236 files changed: +9854 −6731 lines


.github/workflows/main.yml

+4 −2

@@ -16,7 +16,6 @@ jobs:
         python-version: [ 3.6, 3.7 ]
         torch-version: [ 1.5.0, 1.6.0 ]
         tensorflow-version: [ 1.15.0 ]
-
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python ${{ matrix.python-version }}
@@ -35,8 +34,11 @@ jobs:
         run: |
           python -m pip install --progress-bar off --upgrade pip
           pip install --progress-bar off Django django-guardian
-          pip install --progress-bar off pylint==2.6.0 flake8==3.8.2 mypy==0.790 pytest==5.1.3
+          pip install --progress-bar off pylint==2.6.0 flake8==3.8.2 mypy==0.790 pytest==5.1.3 black==20.8b1
           pip install --progress-bar off coverage codecov
+      - name: Format check with Black
+        run: |
+          black --line-length 80 --check forte/
       - name: Obtain Stave Database Examples
         run: |
           git clone https://github.com/asyml/stave.git
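The new step gates CI on formatting. To reproduce it locally before pushing, a minimal sketch reusing the pinned version and flags from the workflow above:

    pip install black==20.8b1
    black --line-length 80 --check forte/

With `--check`, Black only reports which files it would reformat and exits non-zero if any exist; dropping the flag rewrites those files in place.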

.pylintrc

+1 −1

@@ -301,7 +301,7 @@ logging-modules=logging
 [FORMAT]
 
 # Maximum number of characters on a single line.
-max-line-length=80
+max-line-length=100
 
 # Regexp for a line that is allowed to be longer than the limit.
 # This regex matches URLs and link anchors.
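Raising pylint's limit to 100 while Black wraps at 80 is presumably deliberate: Black leaves some lines longer than its target (notably long string literals, which it does not split by default), so a looser pylint ceiling keeps the two tools from flagging each other's output. Note the same Black setting could also live in pyproject.toml; a hypothetical equivalent (this commit passes the flag on the command line instead):

    [tool.black]
    line-length = 80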

README.md

+1 −0

@@ -9,6 +9,7 @@
 [![Documentation Status](https://readthedocs.org/projects/asyml-forte/badge/?version=latest)](https://asyml-forte.readthedocs.io/en/latest/?badge=latest)
 [![License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](https://github.com/asyml/forte/blob/master/LICENSE)
 [![Chat](http://img.shields.io/badge/gitter.im-asyml/forte-blue.svg)](https://gitter.im/asyml/community)
+[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
 
 **Forte** is a toolkit for building Natural Language Processing pipelines,
 featuring cross-task interaction, adaptable data-model interfaces and composable

examples/chatbot/chatbot_example.py

+40 −25

@@ -16,8 +16,7 @@
 from termcolor import colored
 import torch
 
-from forte.nltk import (
-    NLTKSentenceSegmenter, NLTKWordTokenizer, NLTKPOSTagger)
+from forte.nltk import NLTKSentenceSegmenter, NLTKWordTokenizer, NLTKPOSTagger
 from forte.common.configuration import Config
 from forte.data.multi_pack import MultiPack
 from forte.data.readers import MultiPackTerminalReader
@@ -36,30 +35,38 @@ def setup(config: Config) -> Pipeline:
     resource = Resources()
     query_pipeline = Pipeline[MultiPack](resource=resource)
     query_pipeline.set_reader(
-        reader=MultiPackTerminalReader(), config=config.reader)
+        reader=MultiPackTerminalReader(), config=config.reader
+    )
     query_pipeline.add(
-        component=MicrosoftBingTranslator(), config=config.translator)
+        component=MicrosoftBingTranslator(), config=config.translator
+    )
     query_pipeline.add(
-        component=BertBasedQueryCreator(), config=config.query_creator)
-    query_pipeline.add(
-        component=SearchProcessor(), config=config.searcher)
+        component=BertBasedQueryCreator(), config=config.query_creator
+    )
+    query_pipeline.add(component=SearchProcessor(), config=config.searcher)
 
-    top_response_pack_name = config.indexer.response_pack_name + '_0'
+    top_response_pack_name = config.indexer.response_pack_name + "_0"
 
     query_pipeline.add(
         component=NLTKSentenceSegmenter(),
-        selector=NameMatchSelector(select_name=top_response_pack_name))
+        selector=NameMatchSelector(select_name=top_response_pack_name),
+    )
     query_pipeline.add(
         component=NLTKWordTokenizer(),
-        selector=NameMatchSelector(select_name=top_response_pack_name))
+        selector=NameMatchSelector(select_name=top_response_pack_name),
+    )
     query_pipeline.add(
         component=NLTKPOSTagger(),
-        selector=NameMatchSelector(select_name=top_response_pack_name))
+        selector=NameMatchSelector(select_name=top_response_pack_name),
+    )
     query_pipeline.add(
-        component=SRLPredictor(), config=config.SRL,
-        selector=NameMatchSelector(select_name=top_response_pack_name))
+        component=SRLPredictor(),
+        config=config.SRL,
+        selector=NameMatchSelector(select_name=top_response_pack_name),
+    )
     query_pipeline.add(
-        component=MicrosoftBingTranslator(), config=config.back_translator)
+        component=MicrosoftBingTranslator(), config=config.back_translator
+    )
 
     query_pipeline.initialize()
 
@@ -87,28 +94,36 @@ def main(config: Config):
         resource.update(bot_utterance=[response_pack])
 
         english_pack = m_pack.get_pack("pack")
-        print(colored("English Translation of the query: ", "green"),
-              english_pack.text, "\n")
+        print(
+            colored("English Translation of the query: ", "green"),
+            english_pack.text,
+            "\n",
+        )
 
         # Just take the first pack.
-        pack = m_pack.get_pack(config.indexer.response_pack_name_prefix + '_0')
+        pack = m_pack.get_pack(config.indexer.response_pack_name_prefix + "_0")
         print(colored("Retrieved Document", "green"), pack.text, "\n")
-        print(colored("German Translation", "green"),
-              m_pack.get_pack("response").text, "\n")
+        print(
+            colored("German Translation", "green"),
+            m_pack.get_pack("response").text,
+            "\n",
+        )
         for sentence in pack.get(Sentence):
             sent_text = sentence.text
-            print(colored("Sentence:", 'red'), sent_text, "\n")
+            print(colored("Sentence:", "red"), sent_text, "\n")
 
-            print(colored("Semantic role labels:", 'red'))
+            print(colored("Semantic role labels:", "red"))
             for link in pack.get(PredicateLink, sentence):
                 parent = link.get_parent()
                 child = link.get_child()
-                print(f" - \"{child.text}\" is role "
-                      f"{link.arg_type} of "
-                      f"predicate \"{parent.text}\"")
+                print(
+                    f' - "{child.text}" is role '
+                    f"{link.arg_type} of "
+                    f'predicate "{parent.text}"'
+                )
             print()
 
-        input(colored("Press ENTER to continue...\n", 'green'))
+        input(colored("Press ENTER to continue...\n", "green"))
 
 
 if __name__ == "__main__":
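Beyond re-wrapping calls, the hunks above show Black's quote normalization: it prefers double quotes, but flips a string's outer quotes to single when that removes backslash escapes, which is why the f-strings lose their `\"` sequences. A minimal sketch (hypothetical strings):

    print('hello')        # Black rewrites to: print("hello")
    print("say \"hi\"")   # Black rewrites to: print('say "hi"')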

examples/chatbot/config_data.py

+7 −7

@@ -24,7 +24,7 @@
     "sent_b_seq_len": ["int64", "stacked_tensor"],
     "sent_b_segment_ids": ["int64", "stacked_tensor", max_seq_length],
     "sentence_b": ["str", "stacked_tensor"],
-    "label_ids": ["int64", "stacked_tensor"]
+    "label_ids": ["int64", "stacked_tensor"],
 }
 
 train_hparam = {
@@ -33,10 +33,10 @@
     "dataset": {
         "data_name": "data",
         "feature_types": feature_types,
-        "files": "{}/train.pkl".format(pickle_data_dir)
+        "files": "{}/train.pkl".format(pickle_data_dir),
     },
     "shuffle": True,
-    "shuffle_buffer_size": 100
+    "shuffle_buffer_size": 100,
 }
 
 eval_hparam = {
@@ -45,9 +45,9 @@
     "dataset": {
         "data_name": "data",
         "feature_types": feature_types,
-        "files": "{}/eval.pkl".format(pickle_data_dir)
+        "files": "{}/eval.pkl".format(pickle_data_dir),
     },
-    "shuffle": False
+    "shuffle": False,
 }
 
 test_hparam = {
@@ -56,7 +56,7 @@
     "dataset": {
         "data_name": "data",
         "feature_types": feature_types,
-        "files": "{}/test.pkl".format(pickle_data_dir)
+        "files": "{}/test.pkl".format(pickle_data_dir),
     },
-    "shuffle": False
+    "shuffle": False,
 }
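The trailing commas added throughout this file are functional in this Black release, not just cosmetic: a so-called magic trailing comma tells Black to keep a collection exploded one element per line even when it would fit within the line length. A minimal sketch (hypothetical values):

    xs = [1, 2, 3]   # no trailing comma: Black collapses to one line
    xs = [
        1,
        2,
        3,
    ]                # magic trailing comma: Black keeps it exploded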

examples/chatbot/create_index.py

+29 −16

@@ -30,8 +30,9 @@
 logging.basicConfig(level=logging.INFO)
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--config_data", default="config_data",
-                    help="File to read the config from")
+parser.add_argument(
+    "--config_data", default="config_data", help="File to read the config from"
+)
 args = parser.parse_args()
 
 config = yaml.safe_load(open("config.yml", "r"))
@@ -40,11 +41,11 @@
 
 
 class Indexer:
-
     def __init__(self, model_path, torch_device=None):
 
         self.bert = tx.modules.BERTEncoder(
-            pretrained_model_name=None, hparams={"pretrained_model_name": None})
+            pretrained_model_name=None, hparams={"pretrained_model_name": None}
+        )
         self.device = torch_device
         self.bert.to(device=self.device)
 
@@ -54,10 +55,16 @@ def __init__(self, model_path, torch_device=None):
         self.bert.load_state_dict(state_dict["bert"])
 
         self.tokenizer = tx.data.BERTTokenizer(
-            pretrained_model_name="bert-base-uncased")
+            pretrained_model_name="bert-base-uncased"
+        )
 
-        self.index = EmbeddingBasedIndexer(config={
-            "index_type": "GpuIndexFlatIP", "dim": 768, "device": "gpu0"})
+        self.index = EmbeddingBasedIndexer(
+            config={
+                "index_type": "GpuIndexFlatIP",
+                "dim": 768,
+                "device": "gpu0",
+            }
+        )
 
     @torch.no_grad()
     def create_index(self):
@@ -67,9 +74,9 @@ def create_index(self):
             "dataset": {
                 "data_name": "data",
                 "feature_types": config_data.feature_types,
-                "files": ["data/train.pkl", "data/eval.pkl", "data/test.pkl"]
+                "files": ["data/train.pkl", "data/eval.pkl", "data/test.pkl"],
             },
-            "shuffle": False
+            "shuffle": False,
         }
 
         dataset = tx.data.RecordData(hparams=hparams, device=self.device)
@@ -79,23 +86,29 @@ def create_index(self):
         for idx, batch in enumerate(data_iterator):
             ids = range(start, start + len(batch))
             text = batch["sentence_b"]
-            output, _ = self.bert(inputs=batch["sent_b_input_ids"],
-                                  sequence_length=batch["sent_b_seq_len"],
-                                  segment_ids=batch["sent_b_segment_ids"])
+            output, _ = self.bert(
+                inputs=batch["sent_b_input_ids"],
+                sequence_length=batch["sent_b_seq_len"],
+                segment_ids=batch["sent_b_segment_ids"],
+            )
             cls_tokens = output[:, 0, :]  # CLS token is first token
             self.index.add(vectors=cls_tokens, meta_data=dict(zip(ids, text)))
 
             start += len(batch)
 
             if (idx + 1) % 50 == 0:
-                logging.info("Completed %s batches of size %s", idx + 1,
-                             config.indexer.batch_size)
+                logging.info(
+                    "Completed %s batches of size %s",
+                    idx + 1,
+                    config.indexer.batch_size,
+                )
 
         self.index.save(path=config.indexer.model_dir)
 
 
 if __name__ == "__main__":
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    indexer = Indexer(model_path="model/chatbot_model.ckpt",
-                      torch_device=device)
+    indexer = Indexer(
+        model_path="model/chatbot_model.ckpt", torch_device=device
+    )
     indexer.create_index()

examples/chatbot/data_utils.py

+14 −10

@@ -27,8 +27,8 @@ def split_train_eval_test(file_name):
     conversation = None
 
     for line in text_file:
-        pair = line.rstrip('\n').split('\t')
-        num, question = pair[0].split(' ', 1)
+        pair = line.rstrip("\n").split("\t")
+        num, question = pair[0].split(" ", 1)
         answer = pair[1]
 
         if num == "1":
@@ -49,8 +49,8 @@ def split_train_eval_test(file_name):
     num_test = 500
 
     train_data = text_data[0:num_train]
-    val_data = text_data[num_train:num_train + num_val]
-    test_data = text_data[num_train + num_val:num_train + num_val + num_test]
+    val_data = text_data[num_train : num_train + num_val]
+    test_data = text_data[num_train + num_val : num_train + num_val + num_test]
 
     return train_data, val_data, test_data
 
@@ -88,14 +88,17 @@ def _create_conv_with_history(conv, num_qa):
 
     new_text_data = []
     for i, _ in enumerate(conv):
-        history = conv[max(i - num_qa, 0):i]
+        history = conv[max(i - num_qa, 0) : i]
         current_qa = conv[i]
 
         if history:
-            qa_with_history = [sentence for qa in history for sentence in
-                               qa] + current_qa
-            qa_with_history = [' '.join(qa_with_history[:-1]),
-                               qa_with_history[-1]]
+            qa_with_history = [
+                sentence for qa in history for sentence in qa
+            ] + current_qa
+            qa_with_history = [
+                " ".join(qa_with_history[:-1]),
+                qa_with_history[-1],
+            ]
         else:
             qa_with_history = current_qa
 
@@ -121,7 +124,8 @@ def create_dataset_with_history(conversations, num_line=2):
 
     for conversation in conversations:
         conversation_with_history = _create_conv_with_history(
-            conversation, num_line)
+            conversation, num_line
+        )
         proc_text_data.extend(conversation_with_history)
 
     return proc_text_data
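The spaced slice colons above are Black applying PEP 8's rule: in a slice the colon acts like a binary operator, so when either bound is a complex expression it gets a space on both sides, while simple names or literals stay unspaced. A minimal sketch (hypothetical names):

    data[a:b]                 # simple bounds: no spaces
    data[start : start + n]   # complex bound: spaced colon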
