diff --git a/ensemble_by_prob_multi_file_multiepoch.py b/ensemble_by_prob_multi_file_multiepoch.py new file mode 100644 index 0000000..f92bf84 --- /dev/null +++ b/ensemble_by_prob_multi_file_multiepoch.py @@ -0,0 +1,360 @@ +import json +import sys +import collections +from transformers.tokenization_bert import BasicTokenizer +import logging +import numpy as np +import os +from tqdm import tqdm +import math +import glob +#logging = logging.getLogger(__name__) +from functools import partial +from multiprocessing import Pool, cpu_count +from numba import jit + +def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): + """Project the tokenized prediction back to the original text.""" + + # When we created the data, we kept track of the alignment between original + # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So + # now `orig_text` contains the span of our original text corresponding to the + # span that we predicted. + # + # However, `orig_text` may contain extra characters that we don't want in + # our prediction. + # + # For example, let's say: + # pred_text = steve smith + # orig_text = Steve Smith's + # + # We don't want to return `orig_text` because it contains the extra "'s". + # + # We don't want to return `pred_text` because it's already been normalized + # (the SQuAD eval script also does punctuation stripping/lower casing but + # our tokenizer does additional normalization like stripping accent + # characters). + # + # What we really want to return is "Steve Smith". + # + # Therefore, we have to apply a semi-complicated alignment heuristic between + # `pred_text` and `orig_text` to get a character-to-character alignment. This + # can fail in certain cases in which case we just return `orig_text`. 
+ + def _strip_spaces(text): + ns_chars = [] + ns_to_s_map = collections.OrderedDict() + for (i, c) in enumerate(text): + if c == " ": + continue + ns_to_s_map[len(ns_chars)] = i + ns_chars.append(c) + ns_text = "".join(ns_chars) + return (ns_text, ns_to_s_map) + + # We first tokenize `orig_text`, strip whitespace from the result + # and `pred_text`, and check if they are the same length. If they are + # NOT the same length, the heuristic has failed. If they are the same + # length, we assume the characters are one-to-one aligned. + tokenizer = BasicTokenizer(do_lower_case=True) + + tok_text = " ".join(tokenizer.tokenize(orig_text)) + + start_position = tok_text.find(pred_text) + if start_position == -1: + if verbose_logging: + print ("=="*10) + print (tok_text) + print("Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + end_position = start_position + len(pred_text) - 1 + + (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) + (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) + + if len(orig_ns_text) != len(tok_ns_text): + if verbose_logging: + print("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text) + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + # We then project the characters in `pred_text` back to `orig_text` using + # the character-to-character alignment. 
+ tok_s_to_ns_map = {} + for (i, tok_index) in tok_ns_to_s_map.items(): + tok_s_to_ns_map[tok_index] = i + + orig_start_position = None + if start_position in tok_s_to_ns_map: + ns_start_position = tok_s_to_ns_map[start_position] + if ns_start_position in orig_ns_to_s_map: + orig_start_position = orig_ns_to_s_map[ns_start_position] + + if orig_start_position is None: + if verbose_logging: + print("Couldn't map start position") + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + orig_end_position = None + if end_position in tok_s_to_ns_map: + ns_end_position = tok_s_to_ns_map[end_position] + if ns_end_position in orig_ns_to_s_map: + orig_end_position = orig_ns_to_s_map[ns_end_position] + + if orig_end_position is None: + if verbose_logging: + print("Couldn't map end position") + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + + output_text = orig_text[orig_start_position : (orig_end_position + 1)] + return output_text,orig_start_position,orig_end_position + 1 +def _get_best_start_indexes(logits, n_best_size): + """Get the n-best logits from a list.""" + index_and_score = sorted(logits.items(), key=lambda x: x[1][2], reverse=True) + + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes + +def _get_best_end_indexes(logits, n_best_size): + """Get the n-best logits from a list.""" + index_and_score = sorted(logits.items(), key=lambda x: x[1], reverse=True) + + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes + +def _compute_softmax(scores): + """Compute softmax probability over raw logits.""" + if not scores: + return [] + + max_score = None + for score in scores: + if max_score is None or score > max_score: + max_score = score + + exp_scores = [] + total_sum = 0.0 + for score in 
scores: + x = math.exp(score - max_score) + exp_scores.append(x) + total_sum += x + + probs = [] + for score in exp_scores: + probs.append(score / total_sum) + return probs + + +def extract_answer(info,topk=20,max_answer_length=30): + + start_logits = info["start_logits"] + end_logits = info["end_logits"] + + tokens = [start_logits[str(i)][1] for i in range(len(start_logits))] + + start_indexes = _get_best_start_indexes(start_logits,topk) + end_indexes = _get_best_end_indexes(end_logits, topk) + + _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name + "PrelimPrediction", ["start_index", "end_index", "start_logit", "end_logit"] + ) + + prelim_predictions = [] + + for start_index in start_indexes: + for end_index in end_indexes: + if int(end_index) < int(start_index): + continue + length = int(end_index) - int(start_index) + 1 + if length >max_answer_length: + continue + prelim_predictions.append( + _PrelimPrediction( + start_index=start_index, + end_index=end_index, + start_logit=start_logits[start_index][2], + end_logit=end_logits[end_index], + ) + ) + + prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True) + _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name + "NbestPrediction", ["text", "start_logit", "end_logit","start_index", "end_index"] + ) + nbest = [] + seen_predictions = {} + final_text = "" + for pred in prelim_predictions: + if len(nbest) >= topk: + break + tok_tokens = tokens[int(pred.start_index):int(pred.end_index)+1] + orig_tokens = info["ori_tokens"][start_logits[pred.start_index][0]:start_logits[pred.end_index][0] + 1] + + tok_text = " ".join(tok_tokens) + tok_text = tok_text.replace(" ##", "") + tok_text = tok_text.replace("##", "") + + tok_text = tok_text.strip() + tok_text = " ".join(tok_text.split()) + orig_text = " ".join(orig_tokens) + + final_text,start_index,end_index = get_final_text(tok_text, orig_text, do_lower_case=False, 
def ensemble_logits(file_list,filename):
    """Sum the per-token start/end probabilities of several models for one example.

    Every directory in `file_list` must contain a JSON file called `filename`
    with the keys read below:
      * "start_logits": dict keyed by "0", "1", ...; element [1] of each value
        must be identical across models and element [2] is the start score.
      * "end_logits": dict keyed by "0", "1", ... mapping to the end score.
      * "ori_tokens": must be identical across models.

    The scores of each model are turned into probabilities with `softmax`
    and then added position by position.

    Args:
        file_list: Directories, one per model/checkpoint.
        filename: Name of the example file inside each directory.

    Returns:
        The parsed data of the first directory, with its start/end scores
        replaced by the sum of the probabilities of all models.

    Raises:
        IndexError: if `file_list` is empty.
        AssertionError: if the models disagree on "ori_tokens" or on
            element [1] of a start entry.
    """
    data_list = []
    for directory in file_list:
        # Close each file as soon as it is parsed; the caller runs this for
        # every example of every model, so handles must not be left open.
        with open(os.path.join(directory, filename)) as handle:
            data_list.append(json.load(handle))

    # Step 1: convert the raw scores of every model to probabilities in place.
    for data in data_list:
        start_logits = data["start_logits"]
        end_logits = data["end_logits"]

        start_probs = softmax(np.array(
            [start_logits[str(index)][2] for index in range(len(start_logits))]
        ))
        for index in range(len(start_logits)):
            start_logits[str(index)][2] = start_probs[index]

        end_probs = softmax(np.array(
            [end_logits[str(index)] for index in range(len(end_logits))]
        ))
        for index in range(len(end_logits)):
            end_logits[str(index)] = end_probs[index]

    # Step 2: accumulate every other model into the first one.
    data_new = data_list[0]
    for data in data_list[1:]:
        assert data["ori_tokens"] == data_new["ori_tokens"]
        for index in range(len(data_new["start_logits"])):
            key = str(index)
            assert data_new["start_logits"][key][1] == data["start_logits"][key][1]
            data_new["start_logits"][key][2] += data["start_logits"][key][2]
            data_new["end_logits"][key] += data["end_logits"][key]
    return data_new
"results/ensemble_test1/" + EP = "EP4-5" + for version in versions: + for epoch in range(4,6): + v = "{}_{}*{}*".format(version,epoch,_type) + try: + filepath = glob.glob(os.path.join(logit_path,v))[0] + ensemble_data_files.append(filepath) + except: + print ("cant find:", v) + pass + #exit() + print ( ensemble_data_files) + result = collections.OrderedDict() + nbest_result = collections.OrderedDict() + for filename in tqdm(os.listdir(ensemble_data_files[0])): + data = ensemble_logits(ensemble_data_files,filename) + result[filename],nbest_result[filename] = extract_answer(data,topk=20,max_answer_length=30) + #result, nbest_result = multiprocess(ensemble_data_files,os.listdir(ensemble_data_files[0])) + name = "_".join(versions)+"_"+EP+"_nbest20_len30.json" + json.dump(result,open(os.path.join(results_path,name),"w"),ensure_ascii=False,indent=4) + json.dump(nbest_result,open(os.path.join(results_path,name) + '.nbest',"w"),ensure_ascii=False,indent=4) diff --git a/ensemble_by_prob_multi_file_multiepoch_speed.py b/ensemble_by_prob_multi_file_multiepoch_speed.py new file mode 100644 index 0000000..2ca6a45 --- /dev/null +++ b/ensemble_by_prob_multi_file_multiepoch_speed.py @@ -0,0 +1,342 @@ +import json +import sys +import collections +from transformers.tokenization_bert import BasicTokenizer +import logging +import numpy as np +import os +from tqdm import tqdm +import math +import glob +#logging = logging.getLogger(__name__) +from functools import partial +from multiprocessing import Pool, cpu_count +from numba import jit + +def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): + """Project the tokenized prediction back to the original text.""" + + # When we created the data, we kept track of the alignment between original + # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So + # now `orig_text` contains the span of our original text corresponding to the + # span that we predicted. 
+ # + # However, `orig_text` may contain extra characters that we don't want in + # our prediction. + # + # For example, let's say: + # pred_text = steve smith + # orig_text = Steve Smith's + # + # We don't want to return `orig_text` because it contains the extra "'s". + # + # We don't want to return `pred_text` because it's already been normalized + # (the SQuAD eval script also does punctuation stripping/lower casing but + # our tokenizer does additional normalization like stripping accent + # characters). + # + # What we really want to return is "Steve Smith". + # + # Therefore, we have to apply a semi-complicated alignment heuristic between + # `pred_text` and `orig_text` to get a character-to-character alignment. This + # can fail in certain cases in which case we just return `orig_text`. + + def _strip_spaces(text): + ns_chars = [] + ns_to_s_map = collections.OrderedDict() + for (i, c) in enumerate(text): + if c == " ": + continue + ns_to_s_map[len(ns_chars)] = i + ns_chars.append(c) + ns_text = "".join(ns_chars) + return (ns_text, ns_to_s_map) + + # We first tokenize `orig_text`, strip whitespace from the result + # and `pred_text`, and check if they are the same length. If they are + # NOT the same length, the heuristic has failed. If they are the same + # length, we assume the characters are one-to-one aligned. 
+ tokenizer = BasicTokenizer(do_lower_case=True) + + tok_text = " ".join(tokenizer.tokenize(orig_text)) + + start_position = tok_text.find(pred_text) + if start_position == -1: + if verbose_logging: + print ("=="*10) + print (tok_text) + print("Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + end_position = start_position + len(pred_text) - 1 + + (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) + (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) + + if len(orig_ns_text) != len(tok_ns_text): + if verbose_logging: + print("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text) + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + # We then project the characters in `pred_text` back to `orig_text` using + # the character-to-character alignment. + tok_s_to_ns_map = {} + for (i, tok_index) in tok_ns_to_s_map.items(): + tok_s_to_ns_map[tok_index] = i + + orig_start_position = None + if start_position in tok_s_to_ns_map: + ns_start_position = tok_s_to_ns_map[start_position] + if ns_start_position in orig_ns_to_s_map: + orig_start_position = orig_ns_to_s_map[ns_start_position] + + if orig_start_position is None: + if verbose_logging: + print("Couldn't map start position") + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + orig_end_position = None + if end_position in tok_s_to_ns_map: + ns_end_position = tok_s_to_ns_map[end_position] + if ns_end_position in orig_ns_to_s_map: + orig_end_position = orig_ns_to_s_map[ns_end_position] + + if orig_end_position is None: + if verbose_logging: + print("Couldn't map end position") + #return orig_text,0,len(orig_text) + return pred_text.replace(" ",""),0,len(orig_text) + + + output_text = orig_text[orig_start_position : (orig_end_position + 1)] + return output_text,orig_start_position,orig_end_position + 1 
+def _get_best_start_indexes(logits, n_best_size): + """Get the n-best logits from a list.""" + index_and_score = sorted(logits.items(), key=lambda x: x[1][2], reverse=True) + + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes + +def _get_best_end_indexes(logits, n_best_size): + """Get the n-best logits from a list.""" + index_and_score = sorted(logits.items(), key=lambda x: x[1], reverse=True) + + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes + +def _compute_softmax(scores): + """Compute softmax probability over raw logits.""" + if not scores: + return [] + + max_score = None + for score in scores: + if max_score is None or score > max_score: + max_score = score + + exp_scores = [] + total_sum = 0.0 + for score in scores: + x = math.exp(score - max_score) + exp_scores.append(x) + total_sum += x + + probs = [] + for score in exp_scores: + probs.append(score / total_sum) + return probs + + +def extract_answer(info,topk=20,max_answer_length=30): + + start_logits = info["start_logits"] + end_logits = info["end_logits"] + + tokens = [start_logits[str(i)][1] for i in range(len(start_logits))] + + start_indexes = _get_best_start_indexes(start_logits,topk) + end_indexes = _get_best_end_indexes(end_logits, topk) + + _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name + "PrelimPrediction", ["start_index", "end_index", "start_logit", "end_logit"] + ) + + prelim_predictions = [] + + for start_index in start_indexes: + for end_index in end_indexes: + if int(end_index) < int(start_index): + continue + length = int(end_index) - int(start_index) + 1 + if length >max_answer_length: + continue + prelim_predictions.append( + _PrelimPrediction( + start_index=start_index, + end_index=end_index, + start_logit=start_logits[start_index][2], + 
end_logit=end_logits[end_index], + ) + ) + + prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True) + _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name + "NbestPrediction", ["text", "start_logit", "end_logit","start_index", "end_index"] + ) + nbest = [] + seen_predictions = {} + final_text = "" + for pred in prelim_predictions: + if len(nbest) >= topk: + break + tok_tokens = tokens[int(pred.start_index):int(pred.end_index)+1] + orig_tokens = info["ori_tokens"][start_logits[pred.start_index][0]:start_logits[pred.end_index][0] + 1] + + tok_text = " ".join(tok_tokens) + tok_text = tok_text.replace(" ##", "") + tok_text = tok_text.replace("##", "") + + tok_text = tok_text.strip() + tok_text = " ".join(tok_text.split()) + orig_text = " ".join(orig_tokens) + + final_text,start_index,end_index = get_final_text(tok_text, orig_text, do_lower_case=False, verbose_logging=False) + + if final_text in seen_predictions: + continue + seen_predictions[final_text] = True + + nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit,start_index = start_index,end_index = end_index)) + + if not nbest: + nbest.append(_NbestPrediction(text="", start_logit=-1e6, end_logit=-1e6,start_index = 0, end_index = 0)) + + total_scores = [] + best_non_null_entry = None + + + for entry in nbest: + total_scores.append(entry.start_logit + entry.end_logit) + if not best_non_null_entry: + best_non_null_entry = entry + + probs = _compute_softmax(total_scores) + + nbest_json = [] + for (i, entry) in enumerate(nbest): + output = collections.OrderedDict() + output["text"] = entry.text + output["probability"] = probs[i] + output["start_logit"] = entry.start_logit + output["end_logit"] = entry.end_logit + output["start_index"] = entry.start_index + output["end_index"] = entry.end_index + + nbest_json.append(output) + + assert len(nbest_json) >= 1 + assert best_non_null_entry is not None + + 
def ensemble_logits(file_list,filename):
    """Sum the start/end probabilities of several models for one example.

    Every directory in `file_list` must contain a JSON file called `filename`
    whose "start_logits" and "end_logits" values can be converted with
    `np.array` and passed to `softmax`, and whose "ori_tokens" value is
    identical across models.

    NOTE(review): this function replaces "start_logits"/"end_logits" with
    numpy arrays, while `extract_answer` in this same file indexes
    `start_logits[str(i)][1]` (the per-position dict layout). The two do not
    look compatible -- confirm which input layout this script is meant for.

    Args:
        file_list: Directories, one per model/checkpoint.
        filename: Name of the example file inside each directory.

    Returns:
        The parsed data of the first directory with "start_logits" and
        "end_logits" replaced by numpy arrays holding the summed
        probabilities of all models.

    Raises:
        IndexError: if `file_list` is empty.
        AssertionError: if the models disagree on "ori_tokens".
    """
    data_list = []
    for directory in file_list:
        # Close each file right after parsing instead of leaking the handle.
        with open(os.path.join(directory, filename)) as handle:
            data_list.append(json.load(handle))

    data_new = data_list[0]
    data_new["start_logits"] = softmax(np.array(data_new["start_logits"]))
    data_new["end_logits"] = softmax(np.array(data_new["end_logits"]))

    for data in data_list[1:]:
        assert data["ori_tokens"] == data_new["ori_tokens"]
        data_new["start_logits"] += softmax(np.array(data["start_logits"]))
        # The end scores are accumulated as a whole array, exactly like the
        # start scores; the previous per-key form referenced an undefined
        # `index` and raised NameError.
        data_new["end_logits"] += softmax(np.array(data["end_logits"]))
    return data_new
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
    """Project the tokenized prediction back to the original text.

    `orig_text` is the span of the original (whitespace tokenized) text that
    corresponds to the predicted span, but it may contain extra characters:

        pred_text = steve smith
        orig_text = Steve Smith's

    Returning `orig_text` would keep the extra "'s", and `pred_text` has
    already been normalized by the tokenizer. The function therefore
    tokenizes `orig_text`, finds `pred_text` in the tokenized form and maps
    the match back onto `orig_text` through a character-to-character
    alignment that ignores spaces.

    Args:
        pred_text: Predicted answer rebuilt from the model tokens.
        orig_text: Original text covering the predicted span.
        do_lower_case: Accepted for signature compatibility but not used;
            the tokenizer below is always built with `do_lower_case=True`.
        verbose_logging: When true, print the reason an alignment failed.

    Returns:
        A tuple `(text, start, end)`. On success `text` equals
        `orig_text[start:end]`. When the alignment fails the tuple is
        `(pred_text with every space removed, 0, len(orig_text))`.
    """

    def _strip_spaces(text):
        """Return `text` without spaces and a map: stripped index -> index in `text`."""
        ns_chars = []
        ns_to_s_map = collections.OrderedDict()
        for (i, c) in enumerate(text):
            if c == " ":
                continue
            ns_to_s_map[len(ns_chars)] = i
            ns_chars.append(c)
        ns_text = "".join(ns_chars)
        return (ns_text, ns_to_s_map)

    # Every failure path returns the same value, so build it once.
    fallback = (pred_text.replace(" ", ""), 0, len(orig_text))

    # Tokenize `orig_text` and look for `pred_text` inside the result.
    tokenizer = BasicTokenizer(do_lower_case=True)
    tok_text = " ".join(tokenizer.tokenize(orig_text))

    start_position = tok_text.find(pred_text)
    if start_position == -1:
        if verbose_logging:
            print("==" * 10)
            print(tok_text)
            print("Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
        return fallback
    end_position = start_position + len(pred_text) - 1

    # If the space-stripped original and tokenized texts differ in length the
    # heuristic has failed; otherwise the characters are assumed to be
    # one-to-one aligned.
    (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
    (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)

    if len(orig_ns_text) != len(tok_ns_text):
        if verbose_logging:
            # The arguments must be interpolated with %; passing them as extra
            # print() arguments left the '%s' placeholders in the output.
            print("Length not equal after stripping spaces: '%s' vs '%s'" % (orig_ns_text, tok_ns_text))
        return fallback

    # Invert the tokenized map: index in `tok_text` -> space-stripped index.
    tok_s_to_ns_map = {}
    for (i, tok_index) in tok_ns_to_s_map.items():
        tok_s_to_ns_map[tok_index] = i

    orig_start_position = None
    if start_position in tok_s_to_ns_map:
        ns_start_position = tok_s_to_ns_map[start_position]
        if ns_start_position in orig_ns_to_s_map:
            orig_start_position = orig_ns_to_s_map[ns_start_position]

    if orig_start_position is None:
        if verbose_logging:
            print("Couldn't map start position")
        return fallback

    orig_end_position = None
    if end_position in tok_s_to_ns_map:
        ns_end_position = tok_s_to_ns_map[end_position]
        if ns_end_position in orig_ns_to_s_map:
            orig_end_position = orig_ns_to_s_map[ns_end_position]

    if orig_end_position is None:
        if verbose_logging:
            print("Couldn't map end position")
        return fallback

    output_text = orig_text[orig_start_position : (orig_end_position + 1)]
    return output_text, orig_start_position, orig_end_position + 1
def extract_answer(info):
    """Pick the best answer span and an n-best list from ensembled scores.

    Keys read from `info`:
      * "start_logits", "end_logits": sequences of scores indexed by token
        position.
      * "tokens": whitespace-separated string of model tokens.
      * "span_ids" (or "span_id"): sequence giving, for each token position,
        an index into "ori_tokens".
      * "ori_tokens": list of original tokens.

    The 20 best start and 20 best end positions are combined, spans with
    end < start or more than 30 tokens are dropped, and at most 10 distinct
    texts are kept, ranked by start score + end score.

    Args:
        info: Dict with the keys listed above.

    Returns:
        A tuple `(best_text, nbest_json)`. `nbest_json` is a non-empty list
        of OrderedDicts with the keys "text", "probability", "start_logit",
        "end_logit", "start_index" and "end_index"; the two index values are
        the offsets returned by `get_final_text`. `best_text` is the text of
        the first entry.
    """
    start_logits = info["start_logits"]
    end_logits = info["end_logits"]
    tokens = info["tokens"].split()
    # `deal` in split_file_to_directory.py stores this list under "span_id";
    # accept that spelling when "span_ids" is absent.
    span_ids = info["span_ids"] if "span_ids" in info else info["span_id"]

    start_indexes = _get_best_indexes(start_logits, 20)
    end_indexes = _get_best_indexes(end_logits, 20)

    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction", ["start_index", "end_index", "start_logit", "end_logit"]
    )

    prelim_predictions = []
    for start_index in start_indexes:
        for end_index in end_indexes:
            if end_index < start_index:
                continue
            if end_index - start_index + 1 > 30:
                continue
            prelim_predictions.append(
                _PrelimPrediction(
                    start_index=start_index,
                    end_index=end_index,
                    start_logit=start_logits[start_index],
                    end_logit=end_logits[end_index],
                )
            )
    prelim_predictions.sort(key=lambda x: (x.start_logit + x.end_logit), reverse=True)

    _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "NbestPrediction", ["text", "start_logit", "end_logit", "start_index", "end_index"]
    )

    nbest = []
    seen_predictions = set()
    for pred in prelim_predictions:
        if len(nbest) >= 10:
            break
        tok_tokens = tokens[pred.start_index:pred.end_index + 1]
        orig_tokens = info["ori_tokens"][span_ids[pred.start_index]:span_ids[pred.end_index] + 1]

        # Undo the "##" word-piece markers, then normalize whitespace.
        tok_text = " ".join(tok_tokens)
        tok_text = tok_text.replace(" ##", "")
        tok_text = tok_text.replace("##", "")
        tok_text = " ".join(tok_text.split())
        orig_text = " ".join(orig_tokens)

        final_text, start_index, end_index = get_final_text(
            tok_text, orig_text, do_lower_case=False, verbose_logging=False
        )

        # Keep only the highest-scoring span for each distinct text.
        if final_text in seen_predictions:
            continue
        seen_predictions.add(final_text)

        nbest.append(
            _NbestPrediction(
                text=final_text,
                start_logit=pred.start_logit,
                end_logit=pred.end_logit,
                start_index=start_index,
                end_index=end_index,
            )
        )

    if not nbest:
        # Guarantee one entry. The field names must match _NbestPrediction;
        # the previous start_log_prob/end_log_prob keywords raised TypeError.
        nbest.append(
            _NbestPrediction(text="", start_logit=-1e6, end_logit=-1e6, start_index=0, end_index=0)
        )

    best_non_null_entry = nbest[0]
    probs = _compute_softmax([entry.start_logit + entry.end_logit for entry in nbest])

    nbest_json = []
    for (i, entry) in enumerate(nbest):
        output = collections.OrderedDict()
        output["text"] = entry.text
        output["probability"] = probs[i]
        output["start_logit"] = entry.start_logit
        output["end_logit"] = entry.end_logit
        output["start_index"] = entry.start_index
        output["end_index"] = entry.end_index
        nbest_json.append(output)

    assert len(nbest_json) >= 1

    return best_non_null_entry.text, nbest_json
def softmax(x):
    """Return the softmax of ``x`` computed along its last axis.

    The per-row maximum is subtracted before exponentiating so that large
    logits do not overflow ``np.exp``.

    Args:
        x: Array-like of logits with at least one dimension.

    Returns:
        ``numpy.ndarray`` of the same shape as ``x`` whose entries along the
        last axis sum to 1.
    """
    x = np.asarray(x)
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exponentials = np.exp(shifted)
    return exponentials / exponentials.sum(axis=-1, keepdims=True)


def softmax_logits(file_dir, file_dir_new, filename):
    """Rewrite one logit file with its logits replaced by probabilities.

    The JSON file ``file_dir/filename`` is read. Its ``"start_logits"`` table
    is keyed by the stringified position ``"0"``, ``"1"``, ... and holds a
    list per position whose element at index 2 is the logit; its
    ``"end_logits"`` table maps the same keys directly to a logit. Both sets
    of logits are normalised with :func:`softmax` and the result is written to
    ``file_dir_new/filename``.

    Args:
        file_dir: Directory containing the input file.
        file_dir_new: Existing directory that receives the rewritten file.
        filename: Name of the file inside both directories.
    """
    with open(os.path.join(file_dir, filename), encoding="utf-8") as handle:
        data = json.load(handle)

    start_table = data["start_logits"]
    start_keys = [str(position) for position in range(len(start_table))]
    if start_keys:  # softmax of an empty array has no defined maximum
        probs = softmax(np.array([start_table[key][2] for key in start_keys]))
        for key, prob in zip(start_keys, probs):
            start_table[key][2] = float(prob)

    end_table = data["end_logits"]
    end_keys = [str(position) for position in range(len(end_table))]
    if end_keys:
        probs = softmax(np.array([end_table[key] for key in end_keys]))
        for key, prob in zip(end_keys, probs):
            end_table[key] = float(prob)

    with open(os.path.join(file_dir_new, filename), "w", encoding="utf-8") as handle:
        json.dump(data, handle, ensure_ascii=False, indent=4)


if __name__ == "__main__":
    source_dir = sys.argv[1]
    # normpath drops a trailing slash so the last path component is never
    # empty; otherwise the output would land directly in sys.argv[2].
    new_dir = os.path.join(sys.argv[2], os.path.basename(os.path.normpath(source_dir)))
    # Create the target once, up front, and let genuine failures (for
    # example missing permissions) surface instead of being swallowed.
    os.makedirs(new_dir, exist_ok=True)

    for filename in tqdm(os.listdir(source_dir)):
        softmax_logits(source_dir, new_dir, filename)