diff --git a/ensemble.sh b/ensemble.sh new file mode 100644 index 0000000..e69de29 diff --git a/ensemble_by_logit.py b/ensemble_by_logit.py new file mode 100644 index 0000000..14a937c --- /dev/null +++ b/ensemble_by_logit.py @@ -0,0 +1,223 @@ +import json +import sys +import collections +from transformers.tokenization_bert import BasicTokenizer +import logging +#logging = logging.getLogger(__name__) + +def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): + """Project the tokenized prediction back to the original text.""" + + # When we created the data, we kept track of the alignment between original + # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So + # now `orig_text` contains the span of our original text corresponding to the + # span that we predicted. + # + # However, `orig_text` may contain extra characters that we don't want in + # our prediction. + # + # For example, let's say: + # pred_text = steve smith + # orig_text = Steve Smith's + # + # We don't want to return `orig_text` because it contains the extra "'s". + # + # We don't want to return `pred_text` because it's already been normalized + # (the SQuAD eval script also does punctuation stripping/lower casing but + # our tokenizer does additional normalization like stripping accent + # characters). + # + # What we really want to return is "Steve Smith". + # + # Therefore, we have to apply a semi-complicated alignment heuristic between + # `pred_text` and `orig_text` to get a character-to-character alignment. This + # can fail in certain cases in which case we just return `orig_text`. + + def _strip_spaces(text): + ns_chars = [] + ns_to_s_map = collections.OrderedDict() + for (i, c) in enumerate(text): + if c == " ": + continue + ns_to_s_map[len(ns_chars)] = i + ns_chars.append(c) + ns_text = "".join(ns_chars) + return (ns_text, ns_to_s_map) + + # We first tokenize `orig_text`, strip whitespace from the result + # and `pred_text`, and check if they are the same length. If they are + # NOT the same length, the heuristic has failed. If they are the same + # length, we assume the characters are one-to-one aligned. + tokenizer = BasicTokenizer(do_lower_case=True) + + tok_text = " ".join(tokenizer.tokenize(orig_text)) + + start_position = tok_text.find(pred_text) + if start_position == -1: + if verbose_logging: + print ("=="*10) + print (tok_text) + print("Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + end_position = start_position + len(pred_text) - 1 + + (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) + (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) + + if len(orig_ns_text) != len(tok_ns_text): + if verbose_logging: + print("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text) + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + # We then project the characters in `pred_text` back to `orig_text` using + # the character-to-character alignment. 
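+    # The projection below works in two hops: first invert `tok_ns_to_s_map` so a
+    # position in the re-tokenized (spaced) text maps to its index in the
+    # space-stripped string, then use `orig_ns_to_s_map` to map that index back to a
+    # character offset in `orig_text`. If either hop fails, the early returns below
+    # fall back to the de-spaced `pred_text` rather than aborting.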
+ tok_s_to_ns_map = {} + for (i, tok_index) in tok_ns_to_s_map.items(): + tok_s_to_ns_map[tok_index] = i + + orig_start_position = None + if start_position in tok_s_to_ns_map: + ns_start_position = tok_s_to_ns_map[start_position] + if ns_start_position in orig_ns_to_s_map: + orig_start_position = orig_ns_to_s_map[ns_start_position] + + if orig_start_position is None: + if verbose_logging: + print("Couldn't map start position") + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + orig_end_position = None + if end_position in tok_s_to_ns_map: + ns_end_position = tok_s_to_ns_map[end_position] + if ns_end_position in orig_ns_to_s_map: + orig_end_position = orig_ns_to_s_map[ns_end_position] + + if orig_end_position is None: + if verbose_logging: + print("Couldn't map end position") + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + output_text = orig_text[orig_start_position : (orig_end_position + 1)] + return output_text,orig_start_position,orig_end_position + 1 +def _get_best_start_indexes(logits, n_best_size): + """Get the n-best logits from a list.""" + index_and_score = sorted(logits.items(), key=lambda x: x[1][2], reverse=True) + + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes + +def _get_best_end_indexes(logits, n_best_size): + """Get the n-best logits from a list.""" + index_and_score = sorted(logits.items(), key=lambda x: x[1], reverse=True) + + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes +def extract_answer(info): + + start_logits = info["start_logits"] + end_logits = info["end_logits"] + + tokens = [start_logits[str(i)][1] for i in range(len(start_logits))] + + start_indexes = _get_best_start_indexes(start_logits, 10) + end_indexes = _get_best_end_indexes(end_logits, 10) + + _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name + "PrelimPrediction", ["start_index", "end_index", "start_logit", "end_logit"] + ) + + prelim_predictions = [] + + for start_index in start_indexes: + for end_index in end_indexes: + if int(end_index) < int(start_index): + continue + length = int(end_index) - int(start_index) + 1 + if length >30: + continue + prelim_predictions.append( + _PrelimPrediction( + start_index=start_index, + end_index=end_index, + start_logit=start_logits[start_index][2], + end_logit=end_logits[end_index], + ) + ) + + prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True) + _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name + "NbestPrediction", ["text", "start_logit", "end_logit","start_index", "end_index"] + ) + nbest = [] + seen_predictions = {} + final_text = "" + for pred in prelim_predictions: + if len(nbest) >= 10: + break + tok_tokens = tokens[int(pred.start_index):int(pred.end_index)+1] + orig_tokens = info["ori_tokens"][start_logits[pred.start_index][0]:start_logits[pred.end_index][0] + 1] + + tok_text = " ".join(tok_tokens) + tok_text = tok_text.replace(" ##", "") + tok_text = tok_text.replace("##", "") + + tok_text = tok_text.strip() + tok_text = " ".join(tok_text.split()) + orig_text = " ".join(orig_tokens) + + final_text,start_index,end_index = get_final_text(tok_text, orig_text, do_lower_case=False, verbose_logging=True) + break + + if final_text in seen_predictions: + continue 
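+        # A minimal sketch of the per-question record this script assumes as input
+        # (layout inferred from the indexing in extract_answer and ensemble_logits;
+        # the script that produces these dumps is not shown in this diff):
+        #   {
+        #     "ori_tokens":   ["token0", "token1", ...],
+        #     "start_logits": {"0": [orig_token_index, wordpiece_token, start_logit], ...},
+        #     "end_logits":   {"0": end_logit, ...}   # some dumps store a triple here,
+        #                                             # which ensemble_logits() reduces to [2]
+        #   }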
+ seen_predictions[final_text] = True + + return final_text + +def ensemble_logits(file_list): + + data_list = [json.load(open(f)) for f in file_list] + data_new = data_list[0] + + for qid in data_new.keys(): + for index in range(len(data_new[qid]["end_logits"])): + if type(data_new[qid]["end_logits"][str(index)]) in (tuple,list): + data_new[qid]["end_logits"][str(index)] = data_new[qid]["end_logits"][str(index)][2] + else: + break + + for qid in data_list[0].keys(): + + for i in range(1,len(file_list)): + assert data_list[i][qid]["ori_tokens"] == data_new[qid]["ori_tokens"] + for index in range(len(data_new[qid]["start_logits"])): + assert data_new[qid]["start_logits"][str(index)][1] == data_list[i][qid]["start_logits"][str(index)][1] + data_new[qid]["start_logits"][str(index)][2] += data_list[i][qid]["start_logits"][str(index)][2] + try: + data_new[qid]["end_logits"][str(index)] += data_list[i][qid]["end_logits"][str(index)] + except: + data_new[qid]["end_logits"][str(index)] += data_list[i][qid]["end_logits"][str(index)][2] + return data_new + +if __name__ == "__main__": + ensemble_data_files = sys.argv[1:-1] + data = ensemble_logits(ensemble_data_files) + + result = collections.OrderedDict() + for qid,logit in data.items(): + result[qid] = extract_answer(logit) + + json.dump(result,open(sys.argv[-1],"w"),ensure_ascii=False,indent=4) + diff --git a/ensemble_by_logit_multi_file.py b/ensemble_by_logit_multi_file.py new file mode 100644 index 0000000..2d83593 --- /dev/null +++ b/ensemble_by_logit_multi_file.py @@ -0,0 +1,214 @@ +import json +import sys +import collections +from transformers.tokenization_bert import BasicTokenizer +import logging +import os +#logging = logging.getLogger(__name__) + +def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): + """Project the tokenized prediction back to the original text.""" + + # When we created the data, we kept track of the alignment between original + # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So + # now `orig_text` contains the span of our original text corresponding to the + # span that we predicted. + # + # However, `orig_text` may contain extra characters that we don't want in + # our prediction. + # + # For example, let's say: + # pred_text = steve smith + # orig_text = Steve Smith's + # + # We don't want to return `orig_text` because it contains the extra "'s". + # + # We don't want to return `pred_text` because it's already been normalized + # (the SQuAD eval script also does punctuation stripping/lower casing but + # our tokenizer does additional normalization like stripping accent + # characters). + # + # What we really want to return is "Steve Smith". + # + # Therefore, we have to apply a semi-complicated alignment heuristic between + # `pred_text` and `orig_text` to get a character-to-character alignment. This + # can fail in certain cases in which case we just return `orig_text`. + + def _strip_spaces(text): + ns_chars = [] + ns_to_s_map = collections.OrderedDict() + for (i, c) in enumerate(text): + if c == " ": + continue + ns_to_s_map[len(ns_chars)] = i + ns_chars.append(c) + ns_text = "".join(ns_chars) + return (ns_text, ns_to_s_map) + + # We first tokenize `orig_text`, strip whitespace from the result + # and `pred_text`, and check if they are the same length. If they are + # NOT the same length, the heuristic has failed. If they are the same + # length, we assume the characters are one-to-one aligned. 
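+    # Note: the `do_lower_case` argument of get_final_text is not forwarded here;
+    # BasicTokenizer is always constructed with do_lower_case=True, so `orig_text`
+    # is re-tokenized in lower case regardless of the caller's flag.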
+ tokenizer = BasicTokenizer(do_lower_case=True) + + tok_text = " ".join(tokenizer.tokenize(orig_text)) + + start_position = tok_text.find(pred_text) + if start_position == -1: + if verbose_logging: + print ("=="*10) + print (tok_text) + print("Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + end_position = start_position + len(pred_text) - 1 + + (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) + (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) + + if len(orig_ns_text) != len(tok_ns_text): + if verbose_logging: + print("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text) + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + # We then project the characters in `pred_text` back to `orig_text` using + # the character-to-character alignment. + tok_s_to_ns_map = {} + for (i, tok_index) in tok_ns_to_s_map.items(): + tok_s_to_ns_map[tok_index] = i + + orig_start_position = None + if start_position in tok_s_to_ns_map: + ns_start_position = tok_s_to_ns_map[start_position] + if ns_start_position in orig_ns_to_s_map: + orig_start_position = orig_ns_to_s_map[ns_start_position] + + if orig_start_position is None: + if verbose_logging: + print("Couldn't map start position") + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + orig_end_position = None + if end_position in tok_s_to_ns_map: + ns_end_position = tok_s_to_ns_map[end_position] + if ns_end_position in orig_ns_to_s_map: + orig_end_position = orig_ns_to_s_map[ns_end_position] + + if orig_end_position is None: + if verbose_logging: + print("Couldn't map end position") + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + output_text = orig_text[orig_start_position : (orig_end_position + 1)] + return output_text,orig_start_position,orig_end_position + 1 +def _get_best_start_indexes(logits, n_best_size): + """Get the n-best logits from a list.""" + index_and_score = sorted(logits.items(), key=lambda x: x[1][2], reverse=True) + + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes + +def _get_best_end_indexes(logits, n_best_size): + """Get the n-best logits from a list.""" + index_and_score = sorted(logits.items(), key=lambda x: x[1], reverse=True) + + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes +def extract_answer(info): + + start_logits = info["start_logits"] + end_logits = info["end_logits"] + + tokens = [start_logits[str(i)][1] for i in range(len(start_logits))] + + start_indexes = _get_best_start_indexes(start_logits, 10) + end_indexes = _get_best_end_indexes(end_logits, 10) + + _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name + "PrelimPrediction", ["start_index", "end_index", "start_logit", "end_logit"] + ) + + prelim_predictions = [] + + for start_index in start_indexes: + for end_index in end_indexes: + if int(end_index) < int(start_index): + continue + length = int(end_index) - int(start_index) + 1 + if length >30: + continue + prelim_predictions.append( + _PrelimPrediction( + start_index=start_index, + end_index=end_index, + start_logit=start_logits[start_index][2], + end_logit=end_logits[end_index], + ) + ) + + prelim_predictions = 
sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True) + _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name + "NbestPrediction", ["text", "start_logit", "end_logit","start_index", "end_index"] + ) + nbest = [] + seen_predictions = {} + final_text = "" + for pred in prelim_predictions: + if len(nbest) >= 10: + break + tok_tokens = tokens[int(pred.start_index):int(pred.end_index)+1] + orig_tokens = info["ori_tokens"][start_logits[pred.start_index][0]:start_logits[pred.end_index][0] + 1] + + tok_text = " ".join(tok_tokens) + tok_text = tok_text.replace(" ##", "") + tok_text = tok_text.replace("##", "") + + tok_text = tok_text.strip() + tok_text = " ".join(tok_text.split()) + orig_text = " ".join(orig_tokens) + + final_text,start_index,end_index = get_final_text(tok_text, orig_text, do_lower_case=False, verbose_logging=True) + break + + if final_text in seen_predictions: + continue + seen_predictions[final_text] = True + + return final_text + +def ensemble_logits(file_list,filename): + + data_list = [json.load(open(os.path.join(f,filename))) for f in file_list] + data_new = data_list[0] + + + + for i in range(1,len(file_list)): + assert data_list[i]["ori_tokens"] == data_new["ori_tokens"] + for index in range(len(data_new["start_logits"])): + assert data_new["start_logits"][str(index)][1] == data_list[i]["start_logits"][str(index)][1] + data_new["start_logits"][str(index)][2] += data_list[i]["start_logits"][str(index)][2] + data_new["end_logits"][str(index)] += data_list[i]["end_logits"][str(index)] + return data_new + +if __name__ == "__main__": + ensemble_data_files = sys.argv[1:-1] + result = collections.OrderedDict() + + for filename in os.listdir(ensemble_data_files[0]): + data = ensemble_logits(ensemble_data_files,filename) + result[filename] = extract_answer(data) + + json.dump(result,open(sys.argv[-1],"w"),ensure_ascii=False,indent=4) + diff --git a/ensemble_by_prob.py b/ensemble_by_prob.py new file mode 100644 index 0000000..329462d --- /dev/null +++ b/ensemble_by_prob.py @@ -0,0 +1,248 @@ +import json +import sys +import collections +from transformers.tokenization_bert import BasicTokenizer +import logging +import numpy as np +#logging = logging.getLogger(__name__) + +def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): + """Project the tokenized prediction back to the original text.""" + + # When we created the data, we kept track of the alignment between original + # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So + # now `orig_text` contains the span of our original text corresponding to the + # span that we predicted. + # + # However, `orig_text` may contain extra characters that we don't want in + # our prediction. + # + # For example, let's say: + # pred_text = steve smith + # orig_text = Steve Smith's + # + # We don't want to return `orig_text` because it contains the extra "'s". + # + # We don't want to return `pred_text` because it's already been normalized + # (the SQuAD eval script also does punctuation stripping/lower casing but + # our tokenizer does additional normalization like stripping accent + # characters). + # + # What we really want to return is "Steve Smith". + # + # Therefore, we have to apply a semi-complicated alignment heuristic between + # `pred_text` and `orig_text` to get a character-to-character alignment. This + # can fail in certain cases in which case we just return `orig_text`. 
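+    # Worked example for the helper below (illustrative input, not taken from the data):
+    #   _strip_spaces("steve  smith")
+    #   -> ("stevesmith", {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 7, 6: 8, 7: 9, 8: 10, 9: 11})
+    # i.e. spaces are dropped and, for every kept character, its index in the
+    # original string is recorded so span offsets can be projected back later.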
+ + def _strip_spaces(text): + ns_chars = [] + ns_to_s_map = collections.OrderedDict() + for (i, c) in enumerate(text): + if c == " ": + continue + ns_to_s_map[len(ns_chars)] = i + ns_chars.append(c) + ns_text = "".join(ns_chars) + return (ns_text, ns_to_s_map) + + # We first tokenize `orig_text`, strip whitespace from the result + # and `pred_text`, and check if they are the same length. If they are + # NOT the same length, the heuristic has failed. If they are the same + # length, we assume the characters are one-to-one aligned. + tokenizer = BasicTokenizer(do_lower_case=True) + + tok_text = " ".join(tokenizer.tokenize(orig_text)) + + start_position = tok_text.find(pred_text) + if start_position == -1: + if verbose_logging: + print ("=="*10) + print (tok_text) + print("Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + end_position = start_position + len(pred_text) - 1 + + (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) + (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) + + if len(orig_ns_text) != len(tok_ns_text): + if verbose_logging: + print("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text) + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + # We then project the characters in `pred_text` back to `orig_text` using + # the character-to-character alignment. + tok_s_to_ns_map = {} + for (i, tok_index) in tok_ns_to_s_map.items(): + tok_s_to_ns_map[tok_index] = i + + orig_start_position = None + if start_position in tok_s_to_ns_map: + ns_start_position = tok_s_to_ns_map[start_position] + if ns_start_position in orig_ns_to_s_map: + orig_start_position = orig_ns_to_s_map[ns_start_position] + + if orig_start_position is None: + if verbose_logging: + print("Couldn't map start position") + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + orig_end_position = None + if end_position in tok_s_to_ns_map: + ns_end_position = tok_s_to_ns_map[end_position] + if ns_end_position in orig_ns_to_s_map: + orig_end_position = orig_ns_to_s_map[ns_end_position] + + if orig_end_position is None: + if verbose_logging: + print("Couldn't map end position") + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + + output_text = orig_text[orig_start_position : (orig_end_position + 1)] + return output_text,orig_start_position,orig_end_position + 1 +def _get_best_start_indexes(logits, n_best_size): + """Get the n-best logits from a list.""" + index_and_score = sorted(logits.items(), key=lambda x: x[1][2], reverse=True) + + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes + +def _get_best_end_indexes(logits, n_best_size): + """Get the n-best logits from a list.""" + index_and_score = sorted(logits.items(), key=lambda x: x[1], reverse=True) + + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes +def extract_answer(info): + + start_logits = info["start_logits"] + end_logits = info["end_logits"] + + tokens = [start_logits[str(i)][1] for i in range(len(start_logits))] + + start_indexes = _get_best_start_indexes(start_logits, 10) + end_indexes = _get_best_end_indexes(end_logits, 10) + + _PrelimPrediction = collections.namedtuple( # pylint: 
disable=invalid-name + "PrelimPrediction", ["start_index", "end_index", "start_logit", "end_logit"] + ) + + prelim_predictions = [] + + for start_index in start_indexes: + for end_index in end_indexes: + if int(end_index) < int(start_index): + continue + length = int(end_index) - int(start_index) + 1 + if length >30: + continue + prelim_predictions.append( + _PrelimPrediction( + start_index=start_index, + end_index=end_index, + start_logit=start_logits[start_index][2], + end_logit=end_logits[end_index], + ) + ) + + prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True) + _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name + "NbestPrediction", ["text", "start_logit", "end_logit","start_index", "end_index"] + ) + nbest = [] + seen_predictions = {} + final_text = "" + for pred in prelim_predictions: + if len(nbest) >= 10: + break + tok_tokens = tokens[int(pred.start_index):int(pred.end_index)+1] + orig_tokens = info["ori_tokens"][start_logits[pred.start_index][0]:start_logits[pred.end_index][0] + 1] + + tok_text = " ".join(tok_tokens) + tok_text = tok_text.replace(" ##", "") + tok_text = tok_text.replace("##", "") + + tok_text = tok_text.strip() + tok_text = " ".join(tok_text.split()) + orig_text = " ".join(orig_tokens) + + final_text,start_index,end_index = get_final_text(tok_text, orig_text, do_lower_case=False, verbose_logging=True) + break + + if final_text in seen_predictions: + continue + seen_predictions[final_text] = True + + return final_text + + +def softmax(x): + x_row_max = x.max(axis=-1) + x_row_max = x_row_max.reshape(list(x.shape)[:-1]+[1]) + x = x - x_row_max + x_exp = np.exp(x) + x_exp_row_sum = x_exp.sum(axis=-1).reshape(list(x.shape)[:-1]+[1]) + softmax = x_exp / x_exp_row_sum + return softmax + +def ensemble_logits(file_list): + + data_list = [json.load(open(f)) for f in file_list] + + for i in range(len(file_list)): + for qid in data_list[i].keys(): + + logit = [data_list[i][qid]["start_logits"][str(index)][2] for index in range(len(data_list[i][qid]["start_logits"]))] + logit = np.array(logit) + probs = softmax(logit) +# assert sum(probs) == 1 + + for index in range(len(data_list[i][qid]["start_logits"])): + data_list[i][qid]["start_logits"][str(index)][2] = probs[index] + + try: + logit = [data_list[i][qid]["end_logits"][str(index)][2] for index in range(len(data_list[i][qid]["end_logits"]))] + except: + logit = [data_list[i][qid]["end_logits"][str(index)] for index in range(len(data_list[i][qid]["end_logits"]))] + + logit = np.array(logit) + probs = softmax(logit) + for index in range(len(data_list[i][qid]["end_logits"])): + data_list[i][qid]["end_logits"][str(index)] = probs[index] + + + data_new = data_list[0] + for qid in data_list[0].keys(): + + for i in range(1,len(file_list)): + assert data_list[i][qid]["ori_tokens"] == data_new[qid]["ori_tokens"] + for index in range(len(data_new[qid]["start_logits"])): + assert data_new[qid]["start_logits"][str(index)][1] == data_list[i][qid]["start_logits"][str(index)][1] + data_new[qid]["start_logits"][str(index)][2] += data_list[i][qid]["start_logits"][str(index)][2] + data_new[qid]["end_logits"][str(index)] += data_list[i][qid]["end_logits"][str(index)] + return data_new + + +if __name__ == "__main__": + ensemble_data_files = sys.argv[1:-1] + data = ensemble_logits(ensemble_data_files) + + result = collections.OrderedDict() + for qid,logit in data.items(): + result[qid] = extract_answer(logit) + + 
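+    # Usage sketch, inferred from the sys.argv handling above (file names are placeholders):
+    #   python ensemble_by_prob.py preds_model_a.json preds_model_b.json ensemble_predictions.json
+    # i.e. every leading argument is one model's per-question logits dump and the
+    # last argument is the path the merged answers are written to.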
json.dump(result,open(sys.argv[-1],"w"),ensure_ascii=False,indent=4) + diff --git a/ensemble_by_prob_multi_file.py b/ensemble_by_prob_multi_file.py new file mode 100644 index 0000000..5691aed --- /dev/null +++ b/ensemble_by_prob_multi_file.py @@ -0,0 +1,301 @@ +import json +import sys +import collections +from transformers.tokenization_bert import BasicTokenizer +import logging +import numpy as np +import os +from tqdm import tqdm +import math +#logging = logging.getLogger(__name__) + +def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): + """Project the tokenized prediction back to the original text.""" + + # When we created the data, we kept track of the alignment between original + # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So + # now `orig_text` contains the span of our original text corresponding to the + # span that we predicted. + # + # However, `orig_text` may contain extra characters that we don't want in + # our prediction. + # + # For example, let's say: + # pred_text = steve smith + # orig_text = Steve Smith's + # + # We don't want to return `orig_text` because it contains the extra "'s". + # + # We don't want to return `pred_text` because it's already been normalized + # (the SQuAD eval script also does punctuation stripping/lower casing but + # our tokenizer does additional normalization like stripping accent + # characters). + # + # What we really want to return is "Steve Smith". + # + # Therefore, we have to apply a semi-complicated alignment heuristic between + # `pred_text` and `orig_text` to get a character-to-character alignment. This + # can fail in certain cases in which case we just return `orig_text`. + + def _strip_spaces(text): + ns_chars = [] + ns_to_s_map = collections.OrderedDict() + for (i, c) in enumerate(text): + if c == " ": + continue + ns_to_s_map[len(ns_chars)] = i + ns_chars.append(c) + ns_text = "".join(ns_chars) + return (ns_text, ns_to_s_map) + + # We first tokenize `orig_text`, strip whitespace from the result + # and `pred_text`, and check if they are the same length. If they are + # NOT the same length, the heuristic has failed. If they are the same + # length, we assume the characters are one-to-one aligned. + tokenizer = BasicTokenizer(do_lower_case=True) + + tok_text = " ".join(tokenizer.tokenize(orig_text)) + + start_position = tok_text.find(pred_text) + if start_position == -1: + if verbose_logging: + print ("=="*10) + print (tok_text) + print("Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + end_position = start_position + len(pred_text) - 1 + + (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) + (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) + + if len(orig_ns_text) != len(tok_ns_text): + if verbose_logging: + print("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text) + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + # We then project the characters in `pred_text` back to `orig_text` using + # the character-to-character alignment. 
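+    # Ensembling strategy used later in this file: unlike ensemble_by_logit*.py,
+    # which sums raw logits across models, the *_by_prob* variants first softmax each
+    # model's start/end logits into per-question probabilities and then sum those
+    # probabilities (see ensemble_logits below); e.g. logits [1.0, 2.0, 3.0]
+    # become probs of roughly [0.09, 0.24, 0.67] for each model before summing.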
+ tok_s_to_ns_map = {} + for (i, tok_index) in tok_ns_to_s_map.items(): + tok_s_to_ns_map[tok_index] = i + + orig_start_position = None + if start_position in tok_s_to_ns_map: + ns_start_position = tok_s_to_ns_map[start_position] + if ns_start_position in orig_ns_to_s_map: + orig_start_position = orig_ns_to_s_map[ns_start_position] + + if orig_start_position is None: + if verbose_logging: + print("Couldn't map start position") + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + orig_end_position = None + if end_position in tok_s_to_ns_map: + ns_end_position = tok_s_to_ns_map[end_position] + if ns_end_position in orig_ns_to_s_map: + orig_end_position = orig_ns_to_s_map[ns_end_position] + + if orig_end_position is None: + if verbose_logging: + print("Couldn't map end position") + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + + output_text = orig_text[orig_start_position : (orig_end_position + 1)] + return output_text,orig_start_position,orig_end_position + 1 +def _get_best_start_indexes(logits, n_best_size): + """Get the n-best logits from a list.""" + index_and_score = sorted(logits.items(), key=lambda x: x[1][2], reverse=True) + + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes + +def _get_best_end_indexes(logits, n_best_size): + """Get the n-best logits from a list.""" + index_and_score = sorted(logits.items(), key=lambda x: x[1], reverse=True) + + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes + +def _compute_softmax(scores): + """Compute softmax probability over raw logits.""" + if not scores: + return [] + + max_score = None + for score in scores: + if max_score is None or score > max_score: + max_score = score + + exp_scores = [] + total_sum = 0.0 + for score in scores: + x = math.exp(score - max_score) + exp_scores.append(x) + total_sum += x + + probs = [] + for score in exp_scores: + probs.append(score / total_sum) + return probs + + +def extract_answer(info): + + start_logits = info["start_logits"] + end_logits = info["end_logits"] + + tokens = [start_logits[str(i)][1] for i in range(len(start_logits))] + + start_indexes = _get_best_start_indexes(start_logits, 10) + end_indexes = _get_best_end_indexes(end_logits, 10) + + _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name + "PrelimPrediction", ["start_index", "end_index", "start_logit", "end_logit"] + ) + + prelim_predictions = [] + + for start_index in start_indexes: + for end_index in end_indexes: + if int(end_index) < int(start_index): + continue + length = int(end_index) - int(start_index) + 1 + if length >30: + continue + prelim_predictions.append( + _PrelimPrediction( + start_index=start_index, + end_index=end_index, + start_logit=start_logits[start_index][2], + end_logit=end_logits[end_index], + ) + ) + + prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True) + _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name + "NbestPrediction", ["text", "start_logit", "end_logit","start_index", "end_index"] + ) + nbest = [] + seen_predictions = {} + final_text = "" + for pred in prelim_predictions: + if len(nbest) >= 10: + break + tok_tokens = tokens[int(pred.start_index):int(pred.end_index)+1] + orig_tokens = 
info["ori_tokens"][start_logits[pred.start_index][0]:start_logits[pred.end_index][0] + 1] + + tok_text = " ".join(tok_tokens) + tok_text = tok_text.replace(" ##", "") + tok_text = tok_text.replace("##", "") + + tok_text = tok_text.strip() + tok_text = " ".join(tok_text.split()) + orig_text = " ".join(orig_tokens) + + final_text,start_index,end_index = get_final_text(tok_text, orig_text, do_lower_case=False, verbose_logging=False) + + if final_text in seen_predictions: + continue + seen_predictions[final_text] = True + + nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit,start_index = start_index,end_index = end_index)) + + if not nbest: + nbest.append(_NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6)) + + total_scores = [] + best_non_null_entry = None + + + for entry in nbest: + total_scores.append(entry.start_logit + entry.end_logit) + if not best_non_null_entry: + best_non_null_entry = entry + + probs = _compute_softmax(total_scores) + + nbest_json = [] + for (i, entry) in enumerate(nbest): + output = collections.OrderedDict() + output["text"] = entry.text + output["probability"] = probs[i] + output["start_logit"] = entry.start_logit + output["end_logit"] = entry.end_logit + output["start_index"] = entry.start_index + output["end_index"] = entry.end_index + + nbest_json.append(output) + + assert len(nbest_json) >= 1 + assert best_non_null_entry is not None + + return best_non_null_entry.text, nbest_json + + +def softmax(x): + x_row_max = x.max(axis=-1) + x_row_max = x_row_max.reshape(list(x.shape)[:-1]+[1]) + x = x - x_row_max + x_exp = np.exp(x) + x_exp_row_sum = x_exp.sum(axis=-1).reshape(list(x.shape)[:-1]+[1]) + softmax = x_exp / x_exp_row_sum + return softmax + +def ensemble_logits(file_list,filename): + + data_list = [json.load(open(os.path.join(f,filename))) for f in file_list] + data_new = data_list[0] + + for i in range(len(file_list)): + + logit = [data_list[i]["start_logits"][str(index)][2] for index in range(len(data_list[i]["start_logits"]))] + logit = np.array(logit) + probs = softmax(logit) +# assert sum(probs) == 1 + + for index in range(len(data_list[i]["start_logits"])): + data_list[i]["start_logits"][str(index)][2] = probs[index] + + logit = [data_list[i]["end_logits"][str(index)] for index in range(len(data_list[i]["end_logits"]))] + + logit = np.array(logit) + probs = softmax(logit) + for index in range(len(data_list[i]["end_logits"])): + data_list[i]["end_logits"][str(index)] = probs[index] + + + data_new = data_list[0] + + for i in range(1,len(file_list)): + assert data_list[i]["ori_tokens"] == data_new["ori_tokens"] + for index in range(len(data_new["start_logits"])): + assert data_new["start_logits"][str(index)][1] == data_list[i]["start_logits"][str(index)][1] + data_new["start_logits"][str(index)][2] += data_list[i]["start_logits"][str(index)][2] + data_new["end_logits"][str(index)] += data_list[i]["end_logits"][str(index)] + return data_new + +if __name__ == "__main__": + ensemble_data_files = sys.argv[1:-1] + + result = collections.OrderedDict() + nbest_result = collections.OrderedDict() + for filename in tqdm(os.listdir(ensemble_data_files[0])): + data = ensemble_logits(ensemble_data_files,filename) + result[filename],nbest_result[filename] = extract_answer(data) + + json.dump(result,open(sys.argv[-1],"w"),ensure_ascii=False,indent=4) + json.dump(nbest_result,open(sys.argv[-1] + '.nest',"w"),ensure_ascii=False,indent=4) diff --git a/ensemble_by_prob_multi_file_weight.py 
b/ensemble_by_prob_multi_file_weight.py new file mode 100644 index 0000000..4ddd65c --- /dev/null +++ b/ensemble_by_prob_multi_file_weight.py @@ -0,0 +1,270 @@ +import json +import sys +import collections +from transformers.tokenization_bert import BasicTokenizer +import logging +import numpy as np +import os +#logging = logging.getLogger(__name__) +f1_file = "results/f1.json" +def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): + """Project the tokenized prediction back to the original text.""" + + # When we created the data, we kept track of the alignment between original + # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So + # now `orig_text` contains the span of our original text corresponding to the + # span that we predicted. + # + # However, `orig_text` may contain extra characters that we don't want in + # our prediction. + # + # For example, let's say: + # pred_text = steve smith + # orig_text = Steve Smith's + # + # We don't want to return `orig_text` because it contains the extra "'s". + # + # We don't want to return `pred_text` because it's already been normalized + # (the SQuAD eval script also does punctuation stripping/lower casing but + # our tokenizer does additional normalization like stripping accent + # characters). + # + # What we really want to return is "Steve Smith". + # + # Therefore, we have to apply a semi-complicated alignment heuristic between + # `pred_text` and `orig_text` to get a character-to-character alignment. This + # can fail in certain cases in which case we just return `orig_text`. + + orig_text = orig_text.lower() + def _strip_spaces(text): + ns_chars = [] + ns_to_s_map = collections.OrderedDict() + for (i, c) in enumerate(text): + if c == " ": + continue + ns_to_s_map[len(ns_chars)] = i + ns_chars.append(c) + ns_text = "".join(ns_chars) + return (ns_text, ns_to_s_map) + + # We first tokenize `orig_text`, strip whitespace from the result + # and `pred_text`, and check if they are the same length. If they are + # NOT the same length, the heuristic has failed. If they are the same + # length, we assume the characters are one-to-one aligned. + tokenizer = BasicTokenizer(do_lower_case=True) + + tok_text = " ".join(tokenizer.tokenize(orig_text)) + + start_position = tok_text.find(pred_text) + if start_position == -1: + if verbose_logging: + print ("=="*10) + print (tok_text) + print("Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + end_position = start_position + len(pred_text) - 1 + + (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) + (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) + + if len(orig_ns_text) != len(tok_ns_text): + if verbose_logging: + print("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text) + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + # We then project the characters in `pred_text` back to `orig_text` using + # the character-to-character alignment. 
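+    # In this weighted variant, calc_weight() (defined further down) reads each model's
+    # F1 from results/f1.json and rescales it roughly as weight_i = (f1_i - 71.5) / max(f1);
+    # with illustrative scores f1 = [77.0, 78.9] that gives weights of about [0.070, 0.094],
+    # which then multiply each model's softmaxed probabilities in ensemble_logits().
+    # The sibling script ensemble_by_prob_weight.py instead uses weight_i = f1_i * 0.01.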
+ tok_s_to_ns_map = {} + for (i, tok_index) in tok_ns_to_s_map.items(): + tok_s_to_ns_map[tok_index] = i + + orig_start_position = None + if start_position in tok_s_to_ns_map: + ns_start_position = tok_s_to_ns_map[start_position] + if ns_start_position in orig_ns_to_s_map: + orig_start_position = orig_ns_to_s_map[ns_start_position] + + if orig_start_position is None: + if verbose_logging: + print("Couldn't map start position") + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + orig_end_position = None + if end_position in tok_s_to_ns_map: + ns_end_position = tok_s_to_ns_map[end_position] + if ns_end_position in orig_ns_to_s_map: + orig_end_position = orig_ns_to_s_map[ns_end_position] + + if orig_end_position is None: + if verbose_logging: + print("Couldn't map end position") + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + + output_text = orig_text[orig_start_position : (orig_end_position + 1)] + return output_text,orig_start_position,orig_end_position + 1 +def _get_best_start_indexes(logits, n_best_size): + """Get the n-best logits from a list.""" + index_and_score = sorted(logits.items(), key=lambda x: x[1][2], reverse=True) + + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes + +def _get_best_end_indexes(logits, n_best_size): + """Get the n-best logits from a list.""" + index_and_score = sorted(logits.items(), key=lambda x: x[1], reverse=True) + + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes +def extract_answer(info): + + start_logits = info["start_logits"] + end_logits = info["end_logits"] + + tokens = [start_logits[str(i)][1] for i in range(len(start_logits))] + + start_indexes = _get_best_start_indexes(start_logits, 10) + end_indexes = _get_best_end_indexes(end_logits, 10) + + _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name + "PrelimPrediction", ["start_index", "end_index", "start_logit", "end_logit"] + ) + + prelim_predictions = [] + + for start_index in start_indexes: + for end_index in end_indexes: + if int(end_index) < int(start_index): + continue + length = int(end_index) - int(start_index) + 1 + if length >30: + continue + prelim_predictions.append( + _PrelimPrediction( + start_index=start_index, + end_index=end_index, + start_logit=start_logits[start_index][2], + end_logit=end_logits[end_index], + ) + ) + + prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True) + _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name + "NbestPrediction", ["text", "start_logit", "end_logit","start_index", "end_index"] + ) + nbest = [] + seen_predictions = {} + final_text = "" + for pred in prelim_predictions: + if len(nbest) >= 10: + break + tok_tokens = tokens[int(pred.start_index):int(pred.end_index)+1] + orig_tokens = info["ori_tokens"][start_logits[pred.start_index][0]:start_logits[pred.end_index][0] + 1] + + tok_text = " ".join(tok_tokens) + tok_text = tok_text.replace(" ##", "") + tok_text = tok_text.replace("##", "") + + tok_text = tok_text.strip() + tok_text = " ".join(tok_text.split()) + orig_text = " ".join(orig_tokens) + + final_text,start_index,end_index = get_final_text(tok_text, orig_text, do_lower_case=True, verbose_logging=True) + break + + if final_text in seen_predictions: + 
continue + seen_predictions[final_text] = True + + return final_text + + +def softmax(x): + x_row_max = x.max(axis=-1) + x_row_max = x_row_max.reshape(list(x.shape)[:-1]+[1]) + x = x - x_row_max + x_exp = np.exp(x) + x_exp_row_sum = x_exp.sum(axis=-1).reshape(list(x.shape)[:-1]+[1]) + softmax = x_exp / x_exp_row_sum + return softmax + +def ensemble_logits(file_list,filename,weight): + + data_list = [json.load(open(os.path.join(f,filename))) for f in file_list] + data_new = data_list[0] + + for i in range(len(file_list)): + + logit = [data_list[i]["start_logits"][str(index)][2] for index in range(len(data_list[i]["start_logits"]))] + logit = np.array(logit) + probs = softmax(logit) +# assert sum(probs) == 1 + + for index in range(len(data_list[i]["start_logits"])): + data_list[i]["start_logits"][str(index)][2] = probs[index]*weight[i] + + logit = [data_list[i]["end_logits"][str(index)] for index in range(len(data_list[i]["end_logits"]))] + + logit = np.array(logit) + probs = softmax(logit) + for index in range(len(data_list[i]["end_logits"])): + data_list[i]["end_logits"][str(index)] = probs[index]*weight[i] + + + data_new = data_list[0] + + + for i in range(1,len(file_list)): + assert data_list[i]["ori_tokens"] == data_new["ori_tokens"] + for index in range(len(data_new["start_logits"])): + assert data_new["start_logits"][str(index)][1] == data_list[i]["start_logits"][str(index)][1] + data_new["start_logits"][str(index)][2] += data_list[i]["start_logits"][str(index)][2] + data_new["end_logits"][str(index)] += data_list[i]["end_logits"][str(index)] + return data_new + +def calc_weight(ensemble_data_files): + weight = [] + try: + file_f1 = json.load(open(f1_file)) + except: + file_f1 = {} + for f in ensemble_data_files: + if f.endswith("/"): + f = f[:-1] + if f in file_f1: + weight.append(float(file_f1[f])) + else: + w = input("f1 for {}:".format(f)) + weight.append(float(w)) + file_f1[f] = w + json.dump(file_f1,open(f1_file,"w"),indent=4) + + for i,w in enumerate(weight): + weight[i] = (w-71.5)/max(weight) + print (weight) + return weight +if __name__ == "__main__": + ensemble_data_files = sys.argv[1:-1] + + result = collections.OrderedDict() + + weight = calc_weight(ensemble_data_files) + + for filename in os.listdir(ensemble_data_files[0]): + data = ensemble_logits(ensemble_data_files,filename,weight) + result[filename] = extract_answer(data) + + json.dump(result,open(sys.argv[-1],"w"),ensure_ascii=False,indent=4) + diff --git a/ensemble_by_prob_weight.py b/ensemble_by_prob_weight.py new file mode 100644 index 0000000..3980228 --- /dev/null +++ b/ensemble_by_prob_weight.py @@ -0,0 +1,264 @@ +import json +import sys +import collections +from transformers.tokenization_bert import BasicTokenizer +import logging +import numpy as np +#logging = logging.getLogger(__name__) + +f1_file = "results/f1.json" +def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): + """Project the tokenized prediction back to the original text.""" + + # When we created the data, we kept track of the alignment between original + # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So + # now `orig_text` contains the span of our original text corresponding to the + # span that we predicted. + # + # However, `orig_text` may contain extra characters that we don't want in + # our prediction. + # + # For example, let's say: + # pred_text = steve smith + # orig_text = Steve Smith's + # + # We don't want to return `orig_text` because it contains the extra "'s". 
+ # + # We don't want to return `pred_text` because it's already been normalized + # (the SQuAD eval script also does punctuation stripping/lower casing but + # our tokenizer does additional normalization like stripping accent + # characters). + # + # What we really want to return is "Steve Smith". + # + # Therefore, we have to apply a semi-complicated alignment heuristic between + # `pred_text` and `orig_text` to get a character-to-character alignment. This + # can fail in certain cases in which case we just return `orig_text`. + + def _strip_spaces(text): + ns_chars = [] + ns_to_s_map = collections.OrderedDict() + for (i, c) in enumerate(text): + if c == " ": + continue + ns_to_s_map[len(ns_chars)] = i + ns_chars.append(c) + ns_text = "".join(ns_chars) + return (ns_text, ns_to_s_map) + + # We first tokenize `orig_text`, strip whitespace from the result + # and `pred_text`, and check if they are the same length. If they are + # NOT the same length, the heuristic has failed. If they are the same + # length, we assume the characters are one-to-one aligned. + tokenizer = BasicTokenizer(do_lower_case=True) + + tok_text = " ".join(tokenizer.tokenize(orig_text)) + + start_position = tok_text.find(pred_text) + if start_position == -1: + if verbose_logging: + print ("=="*10) + print (tok_text) + print("Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + end_position = start_position + len(pred_text) - 1 + + (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) + (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) + + if len(orig_ns_text) != len(tok_ns_text): + if verbose_logging: + print("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text) + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + # We then project the characters in `pred_text` back to `orig_text` using + # the character-to-character alignment. 
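+    # Unlike the stock SQuAD post-processing, every failure path in this function
+    # (text not found, length mismatch, unmappable start/end) returns the de-spaced
+    # `pred_text` together with dummy offsets (0, len(orig_text)) instead of falling
+    # back to `orig_text`; the original `return orig_text` lines are kept commented out.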
+ tok_s_to_ns_map = {} + for (i, tok_index) in tok_ns_to_s_map.items(): + tok_s_to_ns_map[tok_index] = i + + orig_start_position = None + if start_position in tok_s_to_ns_map: + ns_start_position = tok_s_to_ns_map[start_position] + if ns_start_position in orig_ns_to_s_map: + orig_start_position = orig_ns_to_s_map[ns_start_position] + + if orig_start_position is None: + if verbose_logging: + print("Couldn't map start position") + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + orig_end_position = None + if end_position in tok_s_to_ns_map: + ns_end_position = tok_s_to_ns_map[end_position] + if ns_end_position in orig_ns_to_s_map: + orig_end_position = orig_ns_to_s_map[ns_end_position] + + if orig_end_position is None: + if verbose_logging: + print("Couldn't map end position") + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + + output_text = orig_text[orig_start_position : (orig_end_position + 1)] + return output_text,orig_start_position,orig_end_position + 1 +def _get_best_start_indexes(logits, n_best_size): + """Get the n-best logits from a list.""" + index_and_score = sorted(logits.items(), key=lambda x: x[1][2], reverse=True) + + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes + +def _get_best_end_indexes(logits, n_best_size): + """Get the n-best logits from a list.""" + index_and_score = sorted(logits.items(), key=lambda x: x[1], reverse=True) + + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes +def extract_answer(info): + + start_logits = info["start_logits"] + end_logits = info["end_logits"] + + tokens = [start_logits[str(i)][1] for i in range(len(start_logits))] + + start_indexes = _get_best_start_indexes(start_logits, 10) + end_indexes = _get_best_end_indexes(end_logits, 10) + + _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name + "PrelimPrediction", ["start_index", "end_index", "start_logit", "end_logit"] + ) + + prelim_predictions = [] + + for start_index in start_indexes: + for end_index in end_indexes: + if int(end_index) < int(start_index): + continue + length = int(end_index) - int(start_index) + 1 + if length >30: + continue + prelim_predictions.append( + _PrelimPrediction( + start_index=start_index, + end_index=end_index, + start_logit=start_logits[start_index][2], + end_logit=end_logits[end_index], + ) + ) + + prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True) + _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name + "NbestPrediction", ["text", "start_logit", "end_logit","start_index", "end_index"] + ) + nbest = [] + seen_predictions = {} + final_text = "" + for pred in prelim_predictions: + if len(nbest) >= 10: + break + tok_tokens = tokens[int(pred.start_index):int(pred.end_index)+1] + orig_tokens = info["ori_tokens"][start_logits[pred.start_index][0]:start_logits[pred.end_index][0] + 1] + + tok_text = " ".join(tok_tokens) + tok_text = tok_text.replace(" ##", "") + tok_text = tok_text.replace("##", "") + + tok_text = tok_text.strip() + tok_text = " ".join(tok_text.split()) + orig_text = " ".join(orig_tokens) + + final_text,start_index,end_index = get_final_text(tok_text, orig_text, do_lower_case=False, verbose_logging=True) + break + + if final_text in seen_predictions: + 
continue + seen_predictions[final_text] = True + + return final_text + + +def softmax(x): + x_row_max = x.max(axis=-1) + x_row_max = x_row_max.reshape(list(x.shape)[:-1]+[1]) + x = x - x_row_max + x_exp = np.exp(x) + x_exp_row_sum = x_exp.sum(axis=-1).reshape(list(x.shape)[:-1]+[1]) + softmax = x_exp / x_exp_row_sum + return softmax + +def ensemble_logits(file_list,weight): + + data_list = [json.load(open(f)) for f in file_list] + + for i in range(len(file_list)): + for qid in data_list[i].keys(): + + logit = [data_list[i][qid]["start_logits"][str(index)][2] for index in range(len(data_list[i][qid]["start_logits"]))] + logit = np.array(logit) + probs = softmax(logit) +# assert sum(probs) == 1 + + for index in range(len(data_list[i][qid]["start_logits"])): + data_list[i][qid]["start_logits"][str(index)][2] = probs[index] * weight[i] + + try: + logit = [data_list[i][qid]["end_logits"][str(index)][2] for index in range(len(data_list[i][qid]["end_logits"]))] + except: + logit = [data_list[i][qid]["end_logits"][str(index)] for index in range(len(data_list[i][qid]["end_logits"]))] + + logit = np.array(logit) + probs = softmax(logit) + for index in range(len(data_list[i][qid]["end_logits"])): + data_list[i][qid]["end_logits"][str(index)] = probs[index]* weight[i] + + + data_new = data_list[0] + for qid in data_list[0].keys(): + + for i in range(1,len(file_list)): + assert data_list[i][qid]["ori_tokens"] == data_new[qid]["ori_tokens"] + for index in range(len(data_new[qid]["start_logits"])): + assert data_new[qid]["start_logits"][str(index)][1] == data_list[i][qid]["start_logits"][str(index)][1] + data_new[qid]["start_logits"][str(index)][2] += data_list[i][qid]["start_logits"][str(index)][2] + data_new[qid]["end_logits"][str(index)] += data_list[i][qid]["end_logits"][str(index)] + return data_new + +def calc_weight(ensemble_data_files): + weight = [] + try: + file_f1 = json.load(open(f1_file)) + except: + file_f1 = {} + for f in ensemble_data_files: + if f in file_f1: + weight.append(float(file_f1[f])*0.01) + else: + w = input("f1 for {}:".format(f)) + weight.append(float(w)*0.01) + file_f1[f] = w + json.dump(file_f1,open(f1_file,"w"),indent=4) + return weight +if __name__ == "__main__": + ensemble_data_files = sys.argv[1:-1] + weight = calc_weight(ensemble_data_files) + data = ensemble_logits(ensemble_data_files,weight) + + result = collections.OrderedDict() + for qid,logit in data.items(): + result[qid] = extract_answer(logit) + + json.dump(result,open(sys.argv[-1],"w"),ensure_ascii=False,indent=4) + diff --git a/ensemble_weight.py b/ensemble_weight.py new file mode 100644 index 0000000..8017868 --- /dev/null +++ b/ensemble_weight.py @@ -0,0 +1,213 @@ +# -*- coding: utf-8 -*- +import json +import argparse +import numpy as np +from collections import OrderedDict +import os +__author__ = "aitingliu@bupt.edu.cn" + +path = '/home/ljp/data/lic2020/' + +single_dict = { + 'output_albert_xxlarge_utf8': + { + '2_1': 75.66131, + '2_3': 74.78274, + '3_2': 74.83024, + '4_2': 74.71253 + }, + 'output_data_join_utf8': + { + '1_2': 73.90175, + '5_1': 75.89828, + '6_2': 74.77744, + '8_3': 76.59578, + '14_3': 77.01477, + '17_2': 77.51061, + '17_3':78.87438, + '22_5': 77.63886, + '23_2': 77.44467, + '23_3': 77.96674, + '23_4':78.5541, + '23_5': 78.48484, + '25_4':78.18604, + '26_1': 77.13503, + '26_2':77.49921, + '26_5':78.35698, + '27_1':78.2233, + '27_2':77.68519, + '27_5':78.04175, + '28_5':78.26716, + }, + 'output_roberta_utf8': + { + '2_5': 75.2279, + '3_2': 74.54158, + '4_4': 75.02555, + '5_4': 74.44874, + 
'6_3': 74.87999 + }, + 'output_utf8': + { + '0': 69.42017, + '1': 70.98569, + '2': 73.52785, + '3': 73.26201, + '4': 71.64866, + '5': 71.30978, + '6': 72.71333, + '7': 73.31789, + '11':73.84171, + '12':75.60135 + } +} + +v_dict = { + 'output_data_join_utf8': + { + '14': 77.01477, + '17': 77.51061, + '21': 76.99629, + '22': 77.63886, + '23': 77.44467, + '25':78.18604, + '26': 77.13503, + '27':77.68519, + '28':78.26716, + '29':77.75663, + '33':77.09937 + } + } + + + +def ensemble_v1(pred_file, division, input_files): + """ + text出现频次 和 probability 加权排序 + """ + if 8 in input_files: + input('error! model 8 doesn\'t have the same key as other model! ') + res_list = [json.load(open(path+"output_utf8/{}_{}_nbest_predictions_utf8.json".format(i, division))) for i in input_files] + for item in res_list: + print(len(item)) + res_json = {} + + for k in list(res_list[0].keys()): + text_list = {} + for i in range(len(res_list)): + for j in range(len(res_list[i][k])): + text = res_list[i][k][j]["text"] + prob = res_list[i][k][j]["probability"] + if not text_list.get(text): + # TODO(aitingliu): start_logit 和 end_logit也可以加进来,看看有没有效果增强 + text_list[text] = 1 * prob # 1 * prob + else: + text_list[text] += 1 * prob # 1 * prob + # print(text_list) + # print(sorted(text_list.items(), key=lambda d: d[1], reverse=True)) + res_json[k] = sorted(text_list.items(), key=lambda d: d[1], reverse=True)[0][0] + + json.dump(res_json, open(pred_file, "w"), ensure_ascii=False, indent=4) + +def ensemble_v4_multiple(res_list, pred_file, weight): + """ + res_list: 单模型输出的json + 对模型在test1上F1的权重w做下列计算作为输入的weight + weight = (w-50) / (w.max()) + input_files = [2, 3, 6, 7] + """ + res_json = {} + try: + for k in list(res_list[0].keys()): + text_list = {} + for i in range(len(res_list)): + for j in range(len(res_list[i][k])): + text = res_list[i][k][j]["text"] + prob = res_list[i][k][j]["probability"] + if not text_list.get(text): + text_list[text] = weight[i] * prob # 1 * prob + else: + text_list[text] += weight[i] * prob # 1 * prob + res_json[k] = sorted(text_list.items(), key=lambda d: d[1], reverse=True)[0][0] +# json.dump(res_json, open(pred_file, "w"), ensure_ascii=False, indent=4) + return res_json + except Exception as e1: + print("err1: ", pred_file, e1) + +def get_json_dict(division, model_lst=[]): + json_dict = {} + if len(model_lst) != 0: # 只load指定文件 + for item in model_lst: + model_type, param = item.split('/') + obj = json.load(open(datapath+'{}_{}_nbest_predictions_utf8.json'.format(item, division))) + try: + json_dict[model_type+'/'+ param] = (obj, single_dict[model_type][param]) + except: + json_dict[model_type+'/'+ param] = (obj, v_dict[model_type][param.split("_")[0]]) + return json_dict + + # load全部文件 + for model_type, param_dict in single_dict.items(): + for param, f1 in param_dict.items(): + obj = json.load(open(datapath+'{}/{}_{}_nbest_predictions_utf8.json'.format(model_type, param, division))) + json_dict[model_type+'/'+ param] = (obj, f1) + return json_dict + + +def ensemble_one(model_lst, division, version, weight_all=None): + obj_dict = get_json_dict(division, model_lst) + res_list = [obj_dict[key][0] for key in model_lst] + if version == 'v4': + f1_np = np.array([obj_dict[key][1] for key in model_lst]) + f1_np = (f1_np - weight_all) / f1_np.max() + + if version == 'v4': + file_name = '_'.join(model_lst) + '_w' + str(weight_all) + '.json' + pred_ans = ensemble_v4_multiple(res_list, file_name, f1_np) + elif version == 'v1' and weight_all == None: + file_name = '_'.join(model_lst) + '.json' + pred_ans = 
ensemble_v1_multiple(res_list, file_name) + else: + input('version input error! version=', version) + + sim_name = [] + for item in model_lst: + if 'xxlarge' in item: + sim_name.append('xx'+item.split('/')[1]) + elif 'output_roberta_utf8' in item: + sim_name.append('ro'+item.split('/')[1]) + elif 'output_data_join_utf8' in item: + sim_name.append('dj'+item.split('/')[1]) + else: + sim_name.append(item.split('/')[1]) + if version == 'v4': + file_name = '-'.join(sim_name) + '_w' + str(weight_all) + '.json' + elif version == 'v1': + file_name = '-'.join(sim_name) + '.json' + else: + input('version input error! version=', version) + file_name = '{}_ensemble_{}_'.format(division, version) + file_name + print('file_name: ', file_name) + json.dump(pred_ans, open(os.path.join(export_datapath, file_name), "w"), ensure_ascii=False, indent=4) + +if __name__ == "__main__": + ''' + 参数: group datapath export_datapath constant division + ''' + group = {'output_data_join_utf8': ["14_3", "23_3", "26_5", "17_2", "23_4", "17_3", "27_5", "22_5", "25_4", "28_5"]} # 按照single_dict的模型类型,指定参与ensemble的模型名称。原始json文件名类似:14_3_test1_nbest_predictions_utf8.json + datapath = './results/' # 输出的ensemble json存放路径 + export_datapath = '/home_export/bzw/MRC/code/lic2020/results/ensemble_test1'# 新ensemble json保存目录 + constant = 72 # constant + division = 'test1' # test1/train/dev + version = 'v4' # ensemble version + + + model_name_list = [] + for name in group: + for sub in group[name]: + model_name_list.append(name+'/'+sub) + model_name_list = sorted(model_name_list) + print(model_name_list) + ensemble_one(model_name_list, division, version, constant) + + diff --git a/ensemble_weight2.py b/ensemble_weight2.py new file mode 100644 index 0000000..7fe841d --- /dev/null +++ b/ensemble_weight2.py @@ -0,0 +1,298 @@ +# -*- coding: utf-8 -*- +import json +import argparse +import numpy as np +from collections import OrderedDict +import os +import itertools +__author__ = "aitingliu@bupt.edu.cn" + +path = '/home/ljp/data/lic2020/' + +single_dict = { + 'output_albert_xxlarge_utf8': + { + '2_1': 75.66131, + '2_3': 74.78274, + '3_2': 74.83024, + '4_2': 74.71253 + }, + 'output_data_join_utf8': + { + '1_2': 73.90175, + '5_1': 75.89828, + '6_2': 74.77744, + '8_3': 76.59578, + '14_3': 77.01477, + '17_2': 77.51061, + '17_3':78.87438, + '22_5': 77.63886, + '23_2': 77.44467, + '23_3': 77.96674, + '23_4':78.5541, + '23_5': 78.48484, + '25_4':78.18604, + '26_1': 77.13503, + '26_2':77.49921, + '26_5':78.35698, + '27_1':78.2233, + '27_2':77.68519, + '27_5':78.04175, + '28_5':78.26716, + }, + 'output_roberta_utf8': + { + '2_5': 75.2279, + '3_2': 74.54158, + '4_4': 75.02555, + '5_4': 74.44874, + '6_3': 74.87999 + }, + 'output_utf8': + { + '0': 69.42017, + '1': 70.98569, + '2': 73.52785, + '3': 73.26201, + '4': 71.64866, + '5': 71.30978, + '6': 72.71333, + '7': 73.31789, + '11':73.84171, + '12':75.60135 + } +} + +v_dict = { + 'output_data_join_utf8': + { + '14': 77.01477, + '17': 77.51061, + '22': 77.63886, + '23': 77.44467, + '25':78.18604, + '26': 77.13503, + '27':77.68519, + '28':78.26716, + '29':77.75663, + '33':77.09937 + } + } +dev_single_dict = { + 'output_data_join_utf8': + { + '17_5': 85.771, + '17_4': 85.971, + '17_3': 86.51, + '17_2': 86.537, + '17_1': 85.641, + '22_5': 86.64, + '22_4': 86.234, + '22_3': 86.229, + '22_2': 85.821, + '22_1': 85.814, + '14_5': 85.666, + '14_4': 85.411, + '14_3': 85.731, + '14_2': 85.4, + '14_1': 85.519, + '23_5': 86.06, + '23_4': 86.225, + '23_3': 86.972, + '23_2': 86.976, + '23_1': 86.136, + '25_5': 86.295, + 
'25_4': 86.476, + '25_3': 86.349, + '25_2': 86.797, + '25_1': 85.381, + '26_5': 86.642, + '26_4': 87.041, + '26_3': 87.271, + '26_2': 87.644, + '26_1': 87.684, + '27_5': 85.781, + '27_4': 85.954, + '27_3': 85.912, + '27_2': 86.526, + '27_1': 86.718, + '28_5': 86.424, + '28_4': 86.08, + '28_3': 86.395, + '28_2': 86.267, + '28_1': 86.162, + '29_5': 86.402, + '29_4': 86.656, + '29_3': 86.843, + '29_2': 86.492, + '29_1': 86.284, + '33_5': 86.207, + '33_4': 85.93, + '33_3': 86.022, + '33_2': 86.087, + '33_1': 86.477, + '34_5': 86.913, + '34_4': 86.801, + '34_3': 87.01, + '34_2': 87.08, + '34_1': 85.877, + '35_5': 86.573, + '35_4': 86.549, + '35_3': 85.996, + '35_2': 85.769, + '35_1': 85.574, + '36_5': 86.799, + '36_4': 86.097, + '36_3': 86.22, + '36_2': 86.207, + '36_1': 86.691, + '37_5': 86.936, + '37_4': 87.146, + '37_3': 87.259, + '37_2': 87.509, + '37_1': 87.592, + } +} +def ensemble_v1(pred_file, division, input_files): + """ + text出现频次 和 probability 加权排序 + """ + if 8 in input_files: + input('error! model 8 doesn\'t have the same key as other model! ') + res_list = [json.load(open(path+"output_utf8/{}_{}_nbest_predictions_utf8.json".format(i, division))) for i in input_files] + for item in res_list: + print(len(item)) + res_json = {} + + for k in list(res_list[0].keys()): + text_list = {} + for i in range(len(res_list)): + for j in range(len(res_list[i][k])): + text = res_list[i][k][j]["text"] + prob = res_list[i][k][j]["probability"] + if not text_list.get(text): + # TODO(aitingliu): start_logit 和 end_logit也可以加进来,看看有没有效果增强 + text_list[text] = 1 * prob # 1 * prob + else: + text_list[text] += 1 * prob # 1 * prob + # print(text_list) + # print(sorted(text_list.items(), key=lambda d: d[1], reverse=True)) + res_json[k] = sorted(text_list.items(), key=lambda d: d[1], reverse=True)[0][0] + + json.dump(res_json, open(pred_file, "w"), ensure_ascii=False, indent=4) + +def ensemble_v4_multiple(res_list, pred_file, weight): + """ + res_list: 单模型输出的json + 对模型在test1上F1的权重w做下列计算作为输入的weight + weight = (w-50) / (w.max()) + input_files = [2, 3, 6, 7] + """ + res_json = {} + try: + for k in list(res_list[0].keys()): + text_list = {} + for i in range(len(res_list)): + for j in range(len(res_list[i][k])): + text = res_list[i][k][j]["text"] + prob = res_list[i][k][j]["probability"] + if not text_list.get(text): + text_list[text] = weight[i] * prob # 1 * prob + else: + text_list[text] += weight[i] * prob # 1 * prob + res_json[k] = sorted(text_list.items(), key=lambda d: d[1], reverse=True)[0][0] +# json.dump(res_json, open(pred_file, "w"), ensure_ascii=False, indent=4) + return res_json + except Exception as e1: + print("err1: ", pred_file, e1) + +def get_json_dict(division, model_lst=[]): + + if division == "dev": + single_dict = dev_single_dict + json_dict = {} + if len(model_lst) != 0: # 只load指定文件 + for item in model_lst: + model_type, param = item.split('/') + obj = json.load(open(datapath+'{}_{}_nbest_predictions_utf8.json'.format(item, division))) + try: + json_dict[model_type+'/'+ param] = (obj, single_dict[model_type][param]) + except: + json_dict[model_type+'/'+ param] = (obj, v_dict[model_type][param.split("_")[0]]) + return json_dict + + # load全部文件 + for model_type, param_dict in single_dict.items(): + for param, f1 in param_dict.items(): + obj = json.load(open(datapath+'{}/{}_{}_nbest_predictions_utf8.json'.format(model_type, param, division))) + json_dict[model_type+'/'+ param] = (obj, f1) + return json_dict + + +def ensemble_one(model_lst, division, version, weight_all=None): + obj_dict = 
get_json_dict(division, model_lst) + res_list = [obj_dict[key][0] for key in model_lst] + if version == 'v4': + f1_np = np.array([obj_dict[key][1] for key in model_lst]) + f1_np = (f1_np - weight_all) / f1_np.max() + + if version == 'v4': + file_name = '_'.join(model_lst) + '_w' + str(weight_all) + '.json' + pred_ans = ensemble_v4_multiple(res_list, file_name, f1_np) + elif version == 'v1' and weight_all == None: + file_name = '_'.join(model_lst) + '.json' + pred_ans = ensemble_v1_multiple(res_list, file_name) + else: + input('version input error! version=', version) + + sim_name = [] + for item in model_lst: + if 'xxlarge' in item: + sim_name.append('xx'+item.split('/')[1]) + elif 'output_roberta_utf8' in item: + sim_name.append('ro'+item.split('/')[1]) + elif 'output_data_join_utf8' in item: + sim_name.append('dj'+item.split('/')[1].split("_")[0]) + else: + sim_name.append(item.split('/')[1]) + sim_name = list(set(sim_name)) + if version == 'v4': + file_name = '-'.join(sim_name) + '_w' + str(weight_all) + '.json' + elif version == 'v1': + file_name = '-'.join(sim_name) + '.json' + else: + input('version input error! version=', version) + file_name = '{}_ensemble_{}_'.format(division, version) + file_name + print('file_name: ', file_name) + json.dump(pred_ans, open(os.path.join(export_datapath, file_name), "w"), ensure_ascii=False, indent=4) + +if __name__ == "__main__": + ''' + 参数: group datapath export_datapath constant division + ''' + #group = {'output_data_join_utf8': ["26_2", "14_3", "23_3", "26_5", "17_2", "23_4", "27_1", "27_2", "17_3", "23_5", "27_5", "22_5", "25_4", "28_5", "23_2", "26_1"]} # 按照single_dict的模型类型,指定参与ensemble的模型名称。原始json文件名类似:14_3_test1_nbest_predictions_utf8.json + Ls = ["14","17","22","23","25","26","27","28","29","33","34", "35","36", "37"] + #L = ["14","17","22","23","25","26","27","28"]#之前看过的模型 + #L = ["17","23","25","26","27","28"] #大于78的模型 + for num in range(2,len(Ls) + 1): + for L in itertools.combinations(Ls,num): + group = {'output_data_join_utf8':[]} + for i in range(4,6): + for l in L: + group['output_data_join_utf8'] += [l+"_"+str(i)] + datapath = './results/' # 输出的ensemble json存放路径 + export_datapath = '/home_export/bzw/MRC/code/lic2020/results/ensemble_dev/EP4-5'# 新ensemble json保存目录 + constant = 71.5 # constant + division = 'dev' # test1/train/dev + version = 'v4' # ensemble version + + + model_name_list = [] + for name in group: + for sub in group[name]: + model_name_list.append(name+'/'+sub) + model_name_list = sorted(model_name_list) + print(model_name_list) + ensemble_one(model_name_list, division, version, constant) + + diff --git a/ensemble_without_weight.py b/ensemble_without_weight.py new file mode 100644 index 0000000..3c0eff9 --- /dev/null +++ b/ensemble_without_weight.py @@ -0,0 +1,198 @@ +# -*- coding: utf-8 -*- +import json +import argparse +import numpy as np +from collections import OrderedDict +import os +__author__ = "aitingliu@bupt.edu.cn" + +path = '/home/ljp/data/lic2020/' + +single_dict = { + 'output_albert_xxlarge_utf8': + { + '2_1': 75.66131, + '2_3': 74.78274, + '3_2': 74.83024, + '4_2': 74.71253 + }, + 'output_data_join_utf8': + { + '1_2': 73.90175, + '5_1': 75.89828, + '6_2': 74.77744, + '8_3': 76.59578, + '14_3': 77.01477, + '17_2': 77.51061, + '17_3':78.87438, + '22_5': 77.63886, + '23_2': 77.44467, + '23_3': 77.96674, + '23_4':78.5541, + '23_5': 78.48484, + '25_4':78.18604, + '26_1': 77.13503, + '26_2':77.49921, + '26_5':78.35698, + '27_1':78.2233, + '27_2':77.68519, + '27_5':78.04175, + '28_5':78.26716, + }, + 
'output_roberta_utf8': + { + '2_5': 75.2279, + '3_2': 74.54158, + '4_4': 75.02555, + '5_4': 74.44874, + '6_3': 74.87999 + }, + 'output_utf8': + { + '0': 69.42017, + '1': 70.98569, + '2': 73.52785, + '3': 73.26201, + '4': 71.64866, + '5': 71.30978, + '6': 72.71333, + '7': 73.31789, + '11':73.84171, + '12':75.60135 + } +} + +def ensemble_v1(pred_file, division, input_files): + """ + text出现频次 和 probability 加权排序 + """ + if 8 in input_files: + input('error! model 8 doesn\'t have the same key as other model! ') + res_list = [json.load(open(path+"output_utf8/{}_{}_nbest_predictions_utf8.json".format(i, division))) for i in input_files] + for item in res_list: + print(len(item)) + res_json = {} + + for k in list(res_list[0].keys()): + text_list = {} + for i in range(len(res_list)): + for j in range(len(res_list[i][k])): + text = res_list[i][k][j]["text"] + prob = res_list[i][k][j]["probability"] + if not text_list.get(text): + # TODO(aitingliu): start_logit 和 end_logit也可以加进来,看看有没有效果增强 + text_list[text] = 1 * prob # 1 * prob + else: + text_list[text] += 1 * prob # 1 * prob + # print(text_list) + # print(sorted(text_list.items(), key=lambda d: d[1], reverse=True)) + res_json[k] = sorted(text_list.items(), key=lambda d: d[1], reverse=True)[0][0] + + json.dump(res_json, open(pred_file, "w"), ensure_ascii=False, indent=4) + +def ensemble_v4_multiple(res_list, pred_file, weight): + """ + res_list: 单模型输出的json + 对模型在test1上F1的权重w做下列计算作为输入的weight + weight = (w-50) / (w.max()) + input_files = [2, 3, 6, 7] + """ + res_json = {} + try: + for k in list(res_list[0].keys()): + text_list = {} + for i in range(len(res_list)): + for j in range(len(res_list[i][k])): + text = res_list[i][k][j]["text"] + prob = res_list[i][k][j]["probability"] + if not text_list.get(text): + text_list[text] = weight[i] * prob # 1 * prob + else: + text_list[text] += weight[i] * prob # 1 * prob + res_json[k] = sorted(text_list.items(), key=lambda d: d[1], reverse=True)[0][0] +# json.dump(res_json, open(pred_file, "w"), ensure_ascii=False, indent=4) + return res_json + except Exception as e1: + print("err1: ", pred_file, e1) + +def get_json_dict(division, model_lst=[]): + json_dict = {} + if len(model_lst) != 0: # 只load指定文件 + for item in model_lst: + model_type, param = item.split('/') + obj = json.load(open(datapath+'{}_{}_nbest_predictions_utf8.json'.format(item, division))) + json_dict[model_type+'/'+ param] = (obj, single_dict[model_type][param]) + return json_dict + + # load全部文件 + for model_type, param_dict in single_dict.items(): + for param, f1 in param_dict.items(): + obj = json.load(open(datapath+'{}/{}_{}_nbest_predictions_utf8.json'.format(model_type, param, division))) + json_dict[model_type+'/'+ param] = (obj, f1) + return json_dict + + +def ensemble_one(model_lst, division, version, weight_all=None): + obj_dict = get_json_dict(division, model_lst) + res_list = [obj_dict[key][0] for key in model_lst] + if version == 'v4': + f1_np = np.array([obj_dict[key][1] for key in model_lst]) + f1_np = (f1_np - weight_all) / f1_np.max() + + if version == 'v4': + file_name = '_'.join(model_lst) + '_w' + str(weight_all) + '.json' + pred_ans = ensemble_v4_multiple(res_list, file_name, f1_np) + elif version == 'v1' and weight_all == None: + file_name = '_'.join(model_lst) + '.json' + pred_ans = ensemble_v1_multiple(res_list, file_name) + else: + input('version input error! 
version=' + str(version))
+
+    sim_name = []
+    for item in model_lst:
+        if 'xxlarge' in item:
+            sim_name.append('xx'+item.split('/')[1])
+        elif 'output_roberta_utf8' in item:
+            sim_name.append('ro'+item.split('/')[1])
+        elif 'output_data_join_utf8' in item:
+            sim_name.append('dj'+item.split('/')[1])
+        else:
+            sim_name.append(item.split('/')[1])
+    if version == 'v4':
+        file_name = '-'.join(sim_name) + '_w' + str(weight_all) + '.json'
+    elif version == 'v1':
+        file_name = '-'.join(sim_name) + '.json'
+    else:
+        input('version input error! version=' + str(version))
+    file_name = '{}_ensemble_{}_'.format(division, version) + file_name
+    print('file_name: ', file_name)
+    json.dump(pred_ans, open(os.path.join(export_datapath, file_name), "w"), ensure_ascii=False, indent=4)
+
+if __name__ == "__main__":
+    '''
+    Parameters to edit: group, datapath, export_datapath, constant, division
+    '''
+    #group = {'output_data_join_utf8': ["26_2", "14_3", "23_3", "26_5", "17_2", "23_4", "27_1", "27_2", "17_3", "23_5", "27_5", "22_5", "25_4", "28_5", "23_2", "26_1"]}  # model names taking part in the ensemble, grouped by model type as in single_dict; the source json files are named like 14_3_test1_nbest_predictions_utf8.json
+
+    L = ["14","17","21","22","23","25","26","27","28","29","33"]
+    group = {'output_data_join_utf8':[]}
+    for i in range(1,6):
+        for l in L:
+            group['output_data_join_utf8'] += [l+"_"+str(i)]
+
+    datapath = './results/'  # directory holding the single-model nbest_predictions json files
+    export_datapath = '/home_export/bzw/MRC/code/lic2020/results/ensemble_test1'  # directory where the new ensemble json is saved
+    #constant = 71.5  # constant used by the v4 weighting (not needed for v1)
+    division = 'test1'  # test1/train/dev
+    version = 'v1'  # ensemble version
+
+
+    model_name_list = []
+    for name in group:
+        for sub in group[name]:
+            model_name_list.append(name+'/'+sub)
+    model_name_list = sorted(model_name_list)
+    print(model_name_list)
+    ensemble_one(model_name_list, division, version)
+
+
diff --git a/move_model.py b/move_model.py
new file mode 100644
index 0000000..37d790f
--- /dev/null
+++ b/move_model.py
@@ -0,0 +1,36 @@
+import os
+import sys
+model_type = 'output_data_join_utf8'
+versions = [sys.argv[1]]
+
+source_data = '/home_export/lat/%s/' % model_type
+to_data = '/home_export/bzw/MRC/code/lic2020/results/%s/' % model_type
+
+cmd_list = []
+files = sorted(os.listdir(source_data))
+for file in files:
+    for version in versions:
+        if version in file and 'nbest_predictions' in file:
+            print(file)
+
+            #if 'test1' in file:
+            #    division = 'test1'
+            if 'dev' in file:
+                division = 'dev'
+            #elif 'train' in file:
+            #    division = 'train'
+            else:
+                continue
+                input('error! ')
+            new_file = '{}_{}_nbest_predictions_utf8.json'.format(version, division)
+            cmd = 'cp {} {}'.format(source_data+file, to_data+new_file)
+            #cmd = 'scp {} bzw@10.108.218.217:{}'.format(source_data+file, to_data+new_file)
+            print(cmd)
+            flag = input('move? (empty input = confirm): ')
+            if flag != '':
+                print('don\'t move! ')
+            else:
+                #cmd_list.append(cmd)
+                pro = os.popen(cmd)
+                text = pro.read()
+                print(text.strip())
diff --git a/split_file_to_directory.py b/split_file_to_directory.py
new file mode 100644
index 0000000..f91a832
--- /dev/null
+++ b/split_file_to_directory.py
@@ -0,0 +1,35 @@
+import sys
+import json
+import os
+file_name = sys.argv[1]
+new_directory = sys.argv[2]
+
+
+def deal(data):
+
+    f = data["id"]
+
+    start_logits = {}
+    end_logits = {}
+
+    tokens = data["tokens"].split()
+    assert len(tokens) == len(data["start_logits"])
+    for i, logit in enumerate(data["start_logits"]):
+
+        start_logits[str(i)] = [data["span_id"][i], tokens[i], logit]
+        end_logits[str(i)] = data["end_logits"][i]
+
+    return f, {"start_logits": start_logits, "end_logits": end_logits, "ori_tokens": data["ori_tokens"]}
+
+if __name__ == "__main__":
+    directory = os.path.join(new_directory, file_name.split("/")[-1].split(".")[0])
+    print(directory)
+    os.mkdir(directory)
+    for line in open(file_name):
+        line = line.strip()
+        data = json.loads(line)
+        f, data = deal(data)
+
+        open(os.path.join(directory, f), "w").write(json.dumps(data))
+
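The logit-level ensemble in ensemble_by_logit.py rests on two steps that are easy to miss when reading the diff: calc_weight turns each model's F1 score (read from an f1_file cache or typed in interactively) into a 0-1 weight, and ensemble_logits softmax-normalises every model's start/end logits, scales them by that weight, and sums them position-wise before span extraction. The script references np and f1_file without importing numpy or defining the cache path, so the sketch below is a minimal, self-contained restatement of the same scheme under those assumptions; the helper names and toy data (combine_start_logits, model_a, model_b) are illustrative only and not part of the repo.

# Sketch of the F1-weighted softmax averaging used in ensemble_by_logit.py.
# Assumes numpy is available; weights follow calc_weight's F1 * 0.01 convention.
import numpy as np

def softmax(x):
    x = np.asarray(x, dtype=float)
    x = x - x.max(axis=-1, keepdims=True)   # subtract the row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def combine_start_logits(per_model_logits, f1_scores):
    """per_model_logits: one 1-D array per model, all the same length.
    f1_scores: F1 in percent, converted to 0-1 weights as calc_weight does."""
    weights = [f * 0.01 for f in f1_scores]
    combined = np.zeros_like(np.asarray(per_model_logits[0], dtype=float))
    for logits, w in zip(per_model_logits, weights):
        combined += w * softmax(logits)      # weighted probability mass per token position
    return combined

if __name__ == "__main__":
    # Two toy models scoring four token positions; model B is trusted slightly more.
    model_a = [2.0, 1.0, 0.5, -1.0]
    model_b = [0.5, 3.0, 0.0, -2.0]
    print(combine_start_logits([model_a, model_b], f1_scores=[75.0, 78.0]))

The same weighted sums are what ensemble_logits accumulates into data_new for both start and end positions; extract_answer then picks the best span from the combined scores.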
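The nbest-level scripts (ensemble_weight.py, ensemble_weight2.py, ensemble_without_weight.py) share one v4 voting rule: each model's F1 in percent becomes a weight via (f1 - constant) / max(f1), with constant set to 72 or 71.5 in the __main__ blocks, and every candidate answer text accumulates weight * probability across the models' nbest lists; the highest-scoring text is kept per question. A minimal sketch of that rule follows, with hypothetical names (f1_to_weights, vote, toy_nbest) and toy data standing in for the *_nbest_predictions_utf8.json entries.

# Sketch of ensemble_v4_multiple's weighted vote over nbest candidates.
import numpy as np

def f1_to_weights(f1_scores, constant):
    f1 = np.array(f1_scores, dtype=float)
    return (f1 - constant) / f1.max()        # the (w - constant) / w.max() transform from the scripts

def vote(nbest_per_model, weights):
    """nbest_per_model: one list per model of {"text": ..., "probability": ...}
    dicts for a single question id."""
    scores = {}
    for preds, w in zip(nbest_per_model, weights):
        for p in preds:
            scores[p["text"]] = scores.get(p["text"], 0.0) + w * p["probability"]
    return max(scores.items(), key=lambda kv: kv[1])[0]

if __name__ == "__main__":
    toy_nbest = [
        [{"text": "北京", "probability": 0.6}, {"text": "北京市", "probability": 0.3}],
        [{"text": "北京市", "probability": 0.7}, {"text": "北京", "probability": 0.2}],
    ]
    weights = f1_to_weights([77.0, 78.9], constant=72)
    print(vote(toy_nbest, weights))

Because the weights are affine in F1, a model scoring below the chosen constant contributes negatively to every candidate it proposes, which is why the constant is tuned per experiment in the __main__ blocks.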