diff --git a/classifier_data_lib.py b/classifier_data_lib.py
index 7aaabeb..4657766 100644
--- a/classifier_data_lib.py
+++ b/classifier_data_lib.py
@@ -308,6 +308,345 @@ def _create_examples(self, lines, set_type):
           InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
     return examples
 
+class MisMnliProcessor(MnliProcessor):
+  """Processor for the Mismatched MultiNLI data set (GLUE version)."""
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "MISMNLI"
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "MNLI", "dev_mismatched.tsv")),
+        "dev")
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "MNLI", "test_mismatched.tsv")),
+        "test")
+
+
+class Sst2Processor(DataProcessor):
+  """Processor for the SST-2 data set (GLUE version)."""
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "SST-2", "train.tsv")), "train")
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "SST-2", "dev.tsv")), "dev")
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "SST-2", "test.tsv")), "test")
+
+  def get_labels(self):
+    """See base class."""
+    return ["0", "1"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "SST2"
+
+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training and dev sets."""
+    examples = []
+    for (i, line) in enumerate(lines):
+      if i == 0:
+        continue
+      if set_type != "test":
+        guid = "%s-%s" % (set_type, i)
+        text_a = tokenization.preprocess_text(line[0], lower=FLAGS.do_lower_case)
+        label = tokenization.preprocess_text(line[1], lower=FLAGS.do_lower_case)
+      else:
+        guid = tokenization.preprocess_text(line[0], lower=FLAGS.do_lower_case)
+        # guid = "%s-%s" % (set_type, line[0])
+        text_a = tokenization.preprocess_text(line[1], lower=FLAGS.do_lower_case)
+        label = "0"
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
+    return examples
+
+
+class StsbProcessor(DataProcessor):
+  """Processor for the STS-B data set (GLUE version)."""
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "STS-B", "train.tsv")), "train")
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "STS-B", "dev.tsv")), "dev")
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "STS-B", "test.tsv")), "test")
+
+  def get_labels(self):
+    """See base class."""
+    return [None]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "STSB"
+
+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training and dev sets."""
+    examples = []
+    for (i, line) in enumerate(lines):
+      if i == 0:
+        continue
+      guid = tokenization.preprocess_text(line[0], lower=FLAGS.do_lower_case)
+      # guid = "%s-%s" % (set_type, line[0])
+      text_a = tokenization.preprocess_text(line[7], lower=FLAGS.do_lower_case)
+      text_b = tokenization.preprocess_text(line[8], lower=FLAGS.do_lower_case)
+      if set_type != "test":
+        label = float(line[-1])
+      else:
+        label = 0
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+
+class QqpProcessor(DataProcessor):
+  """Processor for the QQP data set (GLUE version)."""
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "QQP", "train.tsv")), "train")
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "QQP", "dev.tsv")), "dev")
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "QQP", "test.tsv")), "test")
+
+  def get_labels(self):
+    """See base class."""
+    return ["0", "1"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "QQP"
+
+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training and dev sets."""
+    examples = []
+    for (i, line) in enumerate(lines):
+      if i == 0:
+        continue
+      guid = line[0]
+      # guid = "%s-%s" % (set_type, line[0])
+      if set_type != "test":
+        try:
+          text_a = tokenization.preprocess_text(line[3], lower=FLAGS.do_lower_case)
+          text_b = tokenization.preprocess_text(line[4], lower=FLAGS.do_lower_case)
+          label = tokenization.preprocess_text(line[5], lower=FLAGS.do_lower_case)
+        except IndexError:
+          continue
+      else:
+        text_a = tokenization.preprocess_text(line[1], lower=FLAGS.do_lower_case)
+        text_b = tokenization.preprocess_text(line[2], lower=FLAGS.do_lower_case)
+        label = "0"
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+
+class QnliProcessor(DataProcessor):
+  """Processor for the QNLI data set (GLUE version)."""
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "QNLI", "train.tsv")), "train")
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "QNLI", "dev.tsv")),
+        "dev_matched")
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "QNLI", "test.tsv")),
+        "test_matched")
+
+  def get_labels(self):
+    """See base class."""
+    return ["entailment", "not_entailment"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "QNLI"
+
+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training and dev sets."""
+    examples = []
+    for (i, line) in enumerate(lines):
+      if i == 0:
+        continue
+      guid = tokenization.preprocess_text(line[0], lower=FLAGS.do_lower_case)
+      # guid = "%s-%s" % (set_type, line[0])
+      text_a = tokenization.preprocess_text(line[1], lower=FLAGS.do_lower_case)
+      text_b = tokenization.preprocess_text(line[2], lower=FLAGS.do_lower_case)
+      if set_type == "test_matched":
+        label = "entailment"
+      else:
+        label = tokenization.preprocess_text(line[-1], lower=FLAGS.do_lower_case)
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+
+class RteProcessor(DataProcessor):
+  """Processor for the RTE data set (GLUE version)."""
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "RTE", "train.tsv")), "train")
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "RTE", "dev.tsv")), "dev")
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "RTE", "test.tsv")), "test")
+
+  def get_labels(self):
+    """See base class."""
+    return ["entailment", "not_entailment"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "RTE"
+
+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training and dev sets."""
+    examples = []
+    for (i, line) in enumerate(lines):
+      if i == 0:
+        continue
+      guid = tokenization.preprocess_text(line[0], lower=FLAGS.do_lower_case)
+      # guid = "%s-%s" % (set_type, line[0])
+      text_a = tokenization.preprocess_text(line[1], lower=FLAGS.do_lower_case)
+      text_b = tokenization.preprocess_text(line[2], lower=FLAGS.do_lower_case)
+      if set_type == "test":
+        label = "entailment"
+      else:
+        label = tokenization.preprocess_text(line[-1], lower=FLAGS.do_lower_case)
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+
+class WnliProcessor(DataProcessor):
+  """Processor for the WNLI data set (GLUE version)."""
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "WNLI", "train.tsv")), "train")
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "WNLI", "dev.tsv")), "dev")
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "WNLI", "test.tsv")), "test")
+
+  def get_labels(self):
+    """See base class."""
+    return ["0", "1"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "WNLI"
+
+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training and dev sets."""
+    examples = []
+    for (i, line) in enumerate(lines):
+      if i == 0:
+        continue
+      guid = tokenization.preprocess_text(line[0], lower=FLAGS.do_lower_case)
+      # guid = "%s-%s" % (set_type, line[0])
+      text_a = tokenization.preprocess_text(line[1], lower=FLAGS.do_lower_case)
+      text_b = tokenization.preprocess_text(line[2], lower=FLAGS.do_lower_case)
+      if set_type != "test":
+        label = tokenization.preprocess_text(line[-1], lower=FLAGS.do_lower_case)
+      else:
+        label = "0"
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+class AXProcessor(DataProcessor):
+  """Processor for the AX data set (GLUE version)."""
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "diagnostic", "diagnostic.tsv")),
+        "test")
+
+  def get_labels(self):
+    """See base class."""
+    return ["contradiction", "entailment", "neutral"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "AX"
+
+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training and dev sets."""
+    examples = []
+    for (i, line) in enumerate(lines):
+      if i == 0:
+        continue
+      # Note(mingdachen): We will rely on this guid for GLUE submission.
+      guid = tokenization.preprocess_text(line[0], lower=FLAGS.do_lower_case)
+      text_a = tokenization.preprocess_text(line[1], lower=FLAGS.do_lower_case)
+      text_b = tokenization.preprocess_text(line[2], lower=FLAGS.do_lower_case)
+      if set_type == "test":
+        label = "contradiction"
+      else:
+        label = tokenization.preprocess_text(line[-1], lower=FLAGS.do_lower_case)
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
 def convert_single_example(ex_index, example, label_list, max_seq_length,
                            tokenizer):
   """Converts a single `InputExample` into a single `InputFeatures`."""
diff --git a/create_finetuning_data.py b/create_finetuning_data.py
index 186d40b..097efa8 100644
--- a/create_finetuning_data.py
+++ b/create_finetuning_data.py
@@ -38,7 +38,8 @@
     "for the task.")
 
 flags.DEFINE_enum("classification_task_name", "MNLI",
-                  ["CoLA", "MNLI", "MRPC", "XNLI"],
+                  ["COLA", "STS", "SST", "MNLI", "QNLI",
+                   "QQP", "RTE", "MRPC", "WNLI", "XNLI"],
                   "The name of the task to train ALBERT classifier.")
 
 # ALBERT Squad task specific flags.
@@ -92,51 +93,57 @@
 
 
 def generate_classifier_dataset():
-  """Generates classifier dataset and returns input meta data."""
-  assert FLAGS.input_data_dir and FLAGS.classification_task_name
-
-  processors = {
-      "cola": classifier_data_lib.ColaProcessor,
-      "mnli": classifier_data_lib.MnliProcessor,
-      "mrpc": classifier_data_lib.MrpcProcessor,
-      "xnli": classifier_data_lib.XnliProcessor,
-  }
-  task_name = FLAGS.classification_task_name.lower()
-  if task_name not in processors:
-    raise ValueError("Task not found: %s" % (task_name))
-  processor = processors[task_name]()
-  return classifier_data_lib.generate_tf_record_from_data_file(
-      processor,
-      FLAGS.input_data_dir,
-      FLAGS.spm_model_file,
-      train_data_output_path=FLAGS.train_data_output_path,
-      eval_data_output_path=FLAGS.eval_data_output_path,
-      max_seq_length=FLAGS.max_seq_length,
-      do_lower_case=FLAGS.do_lower_case)
+    """Generates classifier dataset and returns input meta data."""
+    assert FLAGS.input_data_dir and FLAGS.classification_task_name
+
+    processors = {
+        "cola": classifier_data_lib.ColaProcessor,
+        "sts": classifier_data_lib.StsbProcessor,
+        "sst": classifier_data_lib.Sst2Processor,
+        "mnli": classifier_data_lib.MnliProcessor,
+        "qnli": classifier_data_lib.QnliProcessor,
+        "qqp": classifier_data_lib.QqpProcessor,
+        "rte": classifier_data_lib.RteProcessor,
+        "mrpc": classifier_data_lib.MrpcProcessor,
+        "wnli": classifier_data_lib.WnliProcessor,
+        "xnli": classifier_data_lib.XnliProcessor,
+    }
+    task_name = FLAGS.classification_task_name.lower()
+    if task_name not in processors:
+        raise ValueError("Task not found: %s" % (task_name))
+    processor = processors[task_name]()
+    return classifier_data_lib.generate_tf_record_from_data_file(
+        processor,
+        FLAGS.input_data_dir,
+        FLAGS.spm_model_file,
+        train_data_output_path=FLAGS.train_data_output_path,
+        eval_data_output_path=FLAGS.eval_data_output_path,
+        max_seq_length=FLAGS.max_seq_length,
+        do_lower_case=FLAGS.do_lower_case)
 
 
 def generate_squad_dataset():
-  """Generates squad training dataset and returns input meta data."""
-  assert FLAGS.squad_data_file
-  return squad_lib.generate_tf_record_from_json_file(
-      FLAGS.squad_data_file, FLAGS.spm_model_file, FLAGS.train_data_output_path,
-      FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
-      FLAGS.doc_stride, FLAGS.version_2_with_negative)
+    """Generates squad training dataset and returns input meta data."""
+    assert FLAGS.squad_data_file
+    return squad_lib.generate_tf_record_from_json_file(
+        FLAGS.squad_data_file, FLAGS.spm_model_file, FLAGS.train_data_output_path,
+        FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
+        FLAGS.doc_stride, FLAGS.version_2_with_negative)
 
 
 def main(_):
-  logging.set_verbosity(logging.INFO)
-  if FLAGS.fine_tuning_task_type == "classification":
-    input_meta_data = generate_classifier_dataset()
-  else:
-    input_meta_data = generate_squad_dataset()
+    logging.set_verbosity(logging.INFO)
+    if FLAGS.fine_tuning_task_type == "classification":
+        input_meta_data = generate_classifier_dataset()
+    else:
+        input_meta_data = generate_squad_dataset()
 
-  with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
-    writer.write(json.dumps(input_meta_data, indent=4) + "\n")
+    with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
+        writer.write(json.dumps(input_meta_data, indent=4) + "\n")
 
 
 if __name__ == "__main__":
-  flags.mark_flag_as_required("spm_model_file")
-  flags.mark_flag_as_required("train_data_output_path")
-  flags.mark_flag_as_required("meta_data_file_path")
-  app.run(main)
+    flags.mark_flag_as_required("spm_model_file")
+    flags.mark_flag_as_required("train_data_output_path")
+    flags.mark_flag_as_required("meta_data_file_path")
+    app.run(main)
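
Usage note (editor's sketch, not part of the patch): the new processors plug into the existing generate_tf_record_from_data_file() pipeline exactly like the old ones. Below is a minimal, hypothetical invocation; the data and model paths are placeholders, max_seq_length=128 is an arbitrary example value, and absl FLAGS are assumed to be initialized because the processors read FLAGS.do_lower_case internally.

    # Hypothetical driver; paths and values are placeholders, not from the patch.
    import classifier_data_lib

    # Any of the new processors (Sst2Processor, StsbProcessor, QqpProcessor,
    # QnliProcessor, RteProcessor, WnliProcessor) is used the same way.
    processor = classifier_data_lib.RteProcessor()
    classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        "/path/to/glue_data",   # must contain an RTE/ directory of *.tsv files
        "/path/to/albert.spm",  # SentencePiece model file
        train_data_output_path="/tmp/rte_train.tf_record",
        eval_data_output_path="/tmp/rte_eval.tf_record",
        max_seq_length=128,
        do_lower_case=True)

Equivalently, through the flags touched above: running create_finetuning_data.py with --fine_tuning_task_type=classification and --classification_task_name=RTE selects classifier_data_lib.RteProcessor via the lower-cased key in the processors dict.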