diff --git a/README.md b/README.md index 20d48ea..06c9de8 100644 --- a/README.md +++ b/README.md @@ -43,12 +43,12 @@ following sections: export GLUE_DIR=glue_data/ export ALBERT_DIR=large/ -export TASK_NAME=COLA +export TASK_NAME=CoLA export OUTPUT_DIR=cola_processed mkdir $OUTPUT_DIR python create_finetuning_data.py \ - --input_data_dir=${GLUE_DIR}/${TASK_NAME}/ \ + --input_data_dir=${GLUE_DIR}/ \ --spm_model_file=${ALBERT_DIR}/vocab/30k-clean.model \ --train_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_train.tf_record \ --eval_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_eval.tf_record \ diff --git a/classifier_data_lib.py b/classifier_data_lib.py index 4657766..f917288 100644 --- a/classifier_data_lib.py +++ b/classifier_data_lib.py @@ -173,18 +173,18 @@ class MnliProcessor(DataProcessor): def get_train_examples(self, data_dir): """See base class.""" return self._create_examples( - self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + self._read_tsv(os.path.join(data_dir, "MNLI", "train.tsv")), "train") def get_dev_examples(self, data_dir): """See base class.""" return self._create_examples( - self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), + self._read_tsv(os.path.join(data_dir, "MNLI", "dev_matched.tsv")), "dev_matched") def get_test_examples(self, data_dir): """See base class.""" return self._create_examples( - self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test") + self._read_tsv(os.path.join(data_dir, "MNLI", "test_matched.tsv")), "test") def get_labels(self): """See base class.""" @@ -220,17 +220,17 @@ class MrpcProcessor(DataProcessor): def get_train_examples(self, data_dir): """See base class.""" return self._create_examples( - self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + self._read_tsv(os.path.join(data_dir, "MRPC", "train.tsv")), "train") def get_dev_examples(self, data_dir): """See base class.""" return self._create_examples( - self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + self._read_tsv(os.path.join(data_dir, "MRPC", "dev.tsv")), "dev") def get_test_examples(self, data_dir): """See base class.""" return self._create_examples( - self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") + self._read_tsv(os.path.join(data_dir, "MRPC", "test.tsv")), "test") def get_labels(self): """See base class.""" @@ -266,17 +266,17 @@ class ColaProcessor(DataProcessor): def get_train_examples(self, data_dir): """See base class.""" return self._create_examples( - self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + self._read_tsv(os.path.join(data_dir, "CoLA", "train.tsv")), "train") def get_dev_examples(self, data_dir): """See base class.""" return self._create_examples( - self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + self._read_tsv(os.path.join(data_dir, "CoLA", "dev.tsv")), "dev") def get_test_examples(self, data_dir): """See base class.""" return self._create_examples( - self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") + self._read_tsv(os.path.join(data_dir, "CoLA", "test.tsv")), "test") def get_labels(self): """See base class."""