diff --git a/.gitignore b/.gitignore index 29dd9bd..5488d44 100644 --- a/.gitignore +++ b/.gitignore @@ -120,4 +120,7 @@ glue_data/* *_meta_data # ouptut dir -*_OUT/* \ No newline at end of file +*_OUT/* +*_out_*/* +# squad +SQuAD/* diff --git a/README.md b/README.md index ff72274..34e89c3 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,24 @@ End of sequence - WIP +#### Training Data Preparation +```bash +export SQUAD_DIR=SQuAD +export SQUAD_VERSION=v1.1 +export ALBERT_DIR=large +export OUTPUT_DIR=squad_out_${SQUAD_VERSION} +mkdir $OUTPUT_DIR + +python create_finetuning_data.py \ +--squad_data_file=${SQUAD_DIR}/train-${SQUAD_VERSION}.json \ +--spm_model_file=${ALBERT_DIR}/vocab/30k-clean.model \ +--train_data_output_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_train.tf_record \ +--meta_data_file_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_meta_data \ +--fine_tuning_task_type=squad \ +--max_seq_length=384 +``` + + ### Multi-GPU training diff --git a/create_finetuning_data.py b/create_finetuning_data.py index 4e3eed0..186d40b 100644 --- a/create_finetuning_data.py +++ b/create_finetuning_data.py @@ -14,15 +14,13 @@ # ============================================================================== """BERT finetuning task dataset generator.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function import json -from absl import app -from absl import flags import tensorflow as tf +from absl import app, flags, logging + import classifier_data_lib import squad_lib @@ -127,6 +125,7 @@ def generate_squad_dataset(): def main(_): + logging.set_verbosity(logging.INFO) if FLAGS.fine_tuning_task_type == "classification": input_meta_data = generate_classifier_dataset() else: diff --git a/squad_lib.py b/squad_lib.py index 7180655..292c207 100644 --- a/squad_lib.py +++ b/squad_lib.py @@ -234,7 +234,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, 
f = np.zeros((max_n, max_m), dtype=np.float32) for (example_index, example) in enumerate(examples): - + if example_index % 100 == 0: logging.info("Converting {}/{} pos {} neg {}".format( example_index, len(examples), cnt_pos, cnt_neg)) @@ -254,20 +254,24 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, example.paragraph_text, lower=FLAGS.do_lower_case), return_unicode=False) + para_tokens_ = [] + for para_token in para_tokens: + if type(para_token) == bytes: + para_token = para_token.decode("utf-8") + para_tokens_.append(para_token) + para_tokens = para_tokens_ + chartok_to_tok_index = [] tok_start_to_chartok_index = [] tok_end_to_chartok_index = [] char_cnt = 0 for i, token in enumerate(para_tokens): - new_token = six.ensure_binary(token).replace( - tokenization.SPIECE_UNDERLINE, b" ") - chartok_to_tok_index.extend([i] * len(new_token)) + chartok_to_tok_index.extend([i] * len(token)) tok_start_to_chartok_index.append(char_cnt) - char_cnt += len(new_token) + char_cnt += len(token) tok_end_to_chartok_index.append(char_cnt - 1) - tok_cat_text = "".join(para_tokens).replace( - tokenization.SPIECE_UNDERLINE.decode("utf-8"), " ") + tok_cat_text = "".join(para_tokens).replace(tokenization.SPIECE_UNDERLINE.decode("utf-8"), " ") n, m = len(paragraph_text), len(tok_cat_text) if n > max_n or m > max_m: @@ -855,7 +859,7 @@ def generate_tf_record_from_json_file(input_file_path, train_writer.close() meta_data = { - "task_type": "bert_squad", + "task_type": "albert_squad", "train_data_size": number_of_examples, "max_seq_length": max_seq_length, "max_query_length": max_query_length, diff --git a/tokenization.py b/tokenization.py index abf6781..d0bb8a8 100644 --- a/tokenization.py +++ b/tokenization.py @@ -93,6 +93,8 @@ def preprocess_text(inputs, remove_space=True, lower=False): if remove_space: outputs = " ".join(inputs.strip().split()) + outputs = outputs.replace('``', '"').replace("''", '"') + if six.PY2 and isinstance(outputs, str): try: outputs = 
six.ensure_text(outputs, "utf-8") @@ -494,4 +496,4 @@ def _is_punctuation(char): cat = unicodedata.category(char) if cat.startswith("P"): return True - return False \ No newline at end of file + return False