Commit: squad preprocessing

kamalkraj committed Nov 2, 2019
1 parent 9ed3a03 commit acb566b
Showing 5 changed files with 41 additions and 15 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -120,4 +120,7 @@ glue_data/*
 *_meta_data
 
 # ouptut dir
-*_OUT/*
+*_OUT/*
+*_out_*/*
+# squad
+SQuAD/*
18 changes: 18 additions & 0 deletions README.md
@@ -101,6 +101,24 @@ End of sequence

 - WIP
 
+#### Training Data Preparation
+```bash
+export SQUAD_DIR=SQuAD
+export SQUAD_VERSION=v1.1
+export ALBERT_DIR=large
+export OUTPUT_DIR=squad_out_${SQUAD_VERSION}
+mkdir $OUTPUT_DIR
+
+python create_finetuning_data.py \
+--squad_data_file=${SQUAD_DIR}/train-${SQUAD_VERSION}.json \
+--spm_model_file=${ALBERT_DIR}/vocab/30k-clean.model \
+--train_data_output_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
+--meta_data_file_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_meta_data \
+--fine_tuning_task_type=squad \
+--max_seq_length=384
+```
+
+
 
 ### Multi-GPU training

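The command above writes two artifacts into `$OUTPUT_DIR`: the serialized training features (`squad_v1.1_train.tf_record`) and a small meta data file. Below is a minimal sketch for sanity-checking the TFRecord output; the feature names and shapes follow the common BERT/ALBERT SQuAD convention and are assumptions here, so verify them against `FeatureWriter` in squad_lib.py.

```python
# Hedged sketch: feature names/shapes are assumed from the usual
# BERT/ALBERT SQuAD convention; squad_lib.py holds the real schema.
import tensorflow as tf

max_seq_length = 384  # must match --max_seq_length used above

name_to_features = {
    "input_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
    "input_mask": tf.io.FixedLenFeature([max_seq_length], tf.int64),
    "segment_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
    "start_positions": tf.io.FixedLenFeature([], tf.int64),
    "end_positions": tf.io.FixedLenFeature([], tf.int64),
}

dataset = tf.data.TFRecordDataset("squad_out_v1.1/squad_v1.1_train.tf_record")
for raw_record in dataset.take(1):
    parsed = tf.io.parse_single_example(raw_record, name_to_features)
    print({name: tensor.shape for name, tensor in parsed.items()})
```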
9 changes: 4 additions & 5 deletions create_finetuning_data.py
@@ -14,15 +14,13 @@
 # ==============================================================================
 """BERT finetuning task dataset generator."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function
 
 import json
 
-from absl import app
-from absl import flags
 import tensorflow as tf
+from absl import app, flags, logging
 
 import classifier_data_lib
 import squad_lib

@@ -127,6 +125,7 @@ def generate_squad_dataset():


 def main(_):
+  logging.set_verbosity(logging.INFO)
   if FLAGS.fine_tuning_task_type == "classification":
     input_meta_data = generate_classifier_dataset()
   else:
20 changes: 12 additions & 8 deletions squad_lib.py
@@ -234,7 +234,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
   f = np.zeros((max_n, max_m), dtype=np.float32)
 
   for (example_index, example) in enumerate(examples):
-
     if example_index % 100 == 0:
       logging.info("Converting {}/{} pos {} neg {}".format(
           example_index, len(examples), cnt_pos, cnt_neg))
@@ -254,20 +254,24 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
             example.paragraph_text, lower=FLAGS.do_lower_case),
         return_unicode=False)
 
+    para_tokens_ = []
+    for para_token in para_tokens:
+      if type(para_token) == bytes:
+        para_token = para_token.decode("utf-8")
+      para_tokens_.append(para_token)
+    para_tokens = para_tokens_
+
     chartok_to_tok_index = []
     tok_start_to_chartok_index = []
     tok_end_to_chartok_index = []
     char_cnt = 0
     for i, token in enumerate(para_tokens):
-      new_token = six.ensure_binary(token).replace(
-          tokenization.SPIECE_UNDERLINE, b" ")
-      chartok_to_tok_index.extend([i] * len(new_token))
+      chartok_to_tok_index.extend([i] * len(token))
       tok_start_to_chartok_index.append(char_cnt)
-      char_cnt += len(new_token)
+      char_cnt += len(token)
       tok_end_to_chartok_index.append(char_cnt - 1)
 
-    tok_cat_text = "".join(para_tokens).replace(
-        tokenization.SPIECE_UNDERLINE.decode("utf-8"), " ")
+    tok_cat_text = "".join(para_tokens).replace(tokenization.SPIECE_UNDERLINE.decode("utf-8"), " ")
     n, m = len(paragraph_text), len(tok_cat_text)
 
     if n > max_n or m > max_m:
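The loop above builds the character-to-token maps used later to convert a character-level answer span into a token-level span; with the SentencePiece tokens now decoded to `str`, `len(token)` counts characters consistently with `tok_cat_text`. A toy illustration (not the repository's code) of how the maps line up:

```python
# Toy example of the index maps built in convert_examples_to_features.
# "▁" is the SentencePiece word-boundary marker (tokenization.SPIECE_UNDERLINE).
para_tokens = ["▁the", "▁quick", "▁fox"]

chartok_to_tok_index = []        # char position in tok_cat_text -> token index
tok_start_to_chartok_index = []  # token index -> first char position
tok_end_to_chartok_index = []    # token index -> last char position
char_cnt = 0
for i, token in enumerate(para_tokens):
  chartok_to_tok_index.extend([i] * len(token))
  tok_start_to_chartok_index.append(char_cnt)
  char_cnt += len(token)
  tok_end_to_chartok_index.append(char_cnt - 1)

tok_cat_text = "".join(para_tokens).replace("▁", " ")
print(tok_cat_text)                # " the quick fox"
print(chartok_to_tok_index[5])     # 1 -> char 5 ("q") belongs to "▁quick"
print(tok_start_to_chartok_index)  # [0, 4, 10]
```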
@@ -855,7 +859,7 @@ def generate_tf_record_from_json_file(input_file_path,
   train_writer.close()
 
   meta_data = {
-      "task_type": "bert_squad",
+      "task_type": "albert_squad",
       "train_data_size": number_of_examples,
       "max_seq_length": max_seq_length,
       "max_query_length": max_query_length,
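The renamed `task_type` ends up in the meta data file written next to the TFRecord. Assuming create_finetuning_data.py serializes this dict as plain JSON (an assumption; check its `main`), downstream training code can recover the preprocessing settings like this:

```python
# Hedged sketch: assumes the meta data file is plain JSON, as is common
# in the TF-official-models style this repo follows.
import json

with open("squad_out_v1.1/squad_v1.1_meta_data") as reader:
  input_meta_data = json.load(reader)

assert input_meta_data["task_type"] == "albert_squad"
print(input_meta_data["max_seq_length"])   # 384 for the command above
print(input_meta_data["train_data_size"])  # number of training features written
```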
4 changes: 3 additions & 1 deletion tokenization.py
@@ -93,6 +93,8 @@ def preprocess_text(inputs, remove_space=True, lower=False):
   if remove_space:
     outputs = " ".join(inputs.strip().split())
 
+  outputs = outputs.replace('``', '"').replace("''", '"')
+
   if six.PY2 and isinstance(outputs, str):
     try:
       outputs = six.ensure_text(outputs, "utf-8")
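The added line folds Penn-Treebank-style quote pairs into plain ASCII double quotes before the text reaches the SentencePiece model. A quick illustration:

```python
# The normalization added in preprocess_text: ``...'' becomes "...".
outputs = "he said ``hello'' twice"
outputs = outputs.replace('``', '"').replace("''", '"')
print(outputs)  # he said "hello" twice
```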
Expand Down Expand Up @@ -494,4 +496,4 @@ def _is_punctuation(char):
   cat = unicodedata.category(char)
   if cat.startswith("P"):
     return True
-  return False
+  return False
