Normalizing and building the dataset (unfinished)
Showing 1 changed file with 380 additions and 0 deletions.
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Normalize and build the Alex Context NLG Dataset"""

from __future__ import unicode_literals
from argparse import ArgumentParser, FileType
import sys
import unicodecsv as csv
import re
from collections import defaultdict

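# NOTE: `util' is assumed to provide DataLine, Result, printout and the
# colour constants (RED, BLUE, YELLOW) used throughout this script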
from util import *
import hunspell

# Start IPdb on error in interactive mode
from tgen.debug import exc_info_hook
sys.excepthook = exc_info_hook

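# Hunspell checker used to flag typos below; the en_US dictionary paths are
# the usual Linux locations and may need adjusting on other systems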
spellchecker = hunspell.HunSpell('/usr/share/hunspell/en_US.dic', '/usr/share/hunspell/en_US.aff')

# normalizing spelling variants and fixing typos
SPELL_FIXES = {'apologise': 'apologize',
               'travelling': 'traveling',
               'travelled': 'traveled',
               'okay': 'ok',
               'dir': 'direction',
               'undertand': 'understand',
               'adn': 'and',
               'aoologies': 'apologies',
               'ar': 'are',
               'directin': 'direction',
               'conections': 'connections',
               'centrall': 'central',
               'fromn': 'from',
               'cortandt': 'cortlandt',
               'confrim': 'confirm',
               'confrirm': 'confirm',
               'looing': 'looking',
               'altenative': 'alternative',
               'alterative': 'alternative',
               'fron': 'from',
               'arrivive': 'arrive',
               'soory': 'sorry',
               'grom': 'from',
               'shedule': 'schedule',
               'connecion': 'connection',
               'whre': 'where',
               'lne': 'line',
               }

REL_TIMES = {'ten minutes': '0:10',
             'quarter an hour': '0:15',
             'fifteen minutes': '0:15',
             'twenty minutes': '0:20',
             'thirty minutes': '0:30',
             'half an hour': '0:30'}

HOURS = ['zero', 'one', 'two', 'three', 'four', 'five', 'six',
         'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve']

CUSTOM_DICT = set([',', ';', ':', '?', '!', '.',
                   'ok',
                   'i\'m', 'i', 'i\'ll', 'i\'ve', 'i\'d',
                   'alrighty'])

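# alternative surface forms of slot values, accepted when delexicalizing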
ALT_VALUES = {'pm': ['afternoon', 'evening'],
              'am': ['morning'],
              '1:00': ['one o\'clock', 'one', '1'],
              '2:00': ['two o\'clock', 'two', '2'],
              '3:00': ['three o\'clock', 'three', '3'],
              '4:00': ['four o\'clock', 'four', '4'],
              '5:00': ['five o\'clock', 'five', '5'],
              '6:00': ['six o\'clock', 'six', '6'],
              '7:00': ['seven o\'clock', 'seven', '7'],
              '8:00': ['eight o\'clock', 'eight', '8'],
              '9:00': ['nine o\'clock', 'nine', '9'],
              '10:00': ['ten o\'clock', 'ten', '10'],
              '11:00': ['eleven o\'clock', 'eleven', '11'],
              '12:00': ['twelve o\'clock', 'twelve', '12'],
              '0:30': ['half', 'thirty', '30'],
              '0:20': ['twenty', '20'],
              '0:15': ['quarter', 'fifteen', '15'],
              '0:10': ['ten', '10']}

def reparse_time(utt):
    """Re-parse a time value from the raw user utterance (used to lexicalize *TIME slots)."""
    # relative times
    for expr, val in REL_TIMES.iteritems():
        if expr in utt:
            return val
    # absolute times
    for hour in HOURS:
        if hour + ' thirty' in utt:
            return str(HOURS.index(hour)) + ':30'
        elif hour in utt:
            return str(HOURS.index(hour)) + ':00'
    raise ValueError('Cannot find time value: ' + utt)
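# e.g. reparse_time('i want to leave at seven thirty') -> '7:30'
#      reparse_time('leaving in half an hour') -> '0:30'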


def reparse_ampm(utt):
    """Re-parse an am/pm value from the raw user utterance (used to lexicalize the ampm slot)."""
    if re.search(r'\bmorning\b', utt):
        return 'am'
    if re.search(r'\b(afternoon|evening)\b', utt):
        return 'pm'
    if re.search(r'(zero|one|two|three|four|five|six|seven|eight|nine|ten|eleven|twelve|thirty|clock|\*NUMBER) (am|a m)\b', utt):
        return 'am'
    if re.search(r'\b(pm|p m)\b', utt):
        return 'pm'
    raise ValueError('Cannot find AMPM value: ' + utt)
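# e.g. reparse_ampm('ten in the morning') -> 'am'
#      reparse_ampm('seven pm') -> 'pm'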


def convert_parse(data_line):
    """Convert the abstract SLU parse to the output format: unify *STREET with
    *STOP and strip placeholder values from request() slots."""
    ret = data_line.abstr_da
    ret = ret.replace('*STREET', '*STOP')
    ret = re.sub(r'(request\([a-z_]+)="\*[A-Z_]+"\)', r'\1)', ret)
    return ret
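# e.g. 'request(departure_time="*DEPARTURE_TIME")' -> 'request(departure_time)'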

def lexicalize_parse(data_line, delex_parse):
    """Lexicalize abstract SLU parse (with values from the lexicalized reply)."""
    ret = delex_parse
    svs = data_line.slot_values
    for slot, value in svs.iteritems():
        delex_value = '*' + slot.upper()
        if 'time' in slot and delex_value in ret:
            value = reparse_time(data_line.utt)
        elif slot == 'ampm' and delex_value in ret:
            value = reparse_ampm(data_line.utt)
        ret = ret.replace(delex_value, value)
    return ret


def convert_da(data_line, delex=False):
    """Convert DA into Alex format (and possibly delexicalize)."""
    da = data_line.da
    if ';' in da:
        da = da.split('; ')
    else:
        da = [data_line.dat + ': ' + da]

    ret = []
    for da_part in da:
        dat, slots = da_part.split(': ')
        if dat == 'reply':
            dat = 'inform'
        elif dat == 'apologize':
            dat = 'inform_no_match'
        elif dat == 'confirm':
            dat = 'iconfirm'
        slots = [slot.split('=') for slot in slots.split(', ') if 'notfound' not in slot]
        if delex:
            ret.extend([dat + '(' + slot_name + ('="*' + slot_name.upper() + '"' if val != '?' else '') + ')'
                        for slot_name, val in slots])
        else:
            ret.extend([dat + '(' + slot_name + ('="' + val + '"' if val != '?' else '') + ')'
                        for slot_name, val in slots])
    return '&'.join(ret)
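# e.g. with dat='reply' and da='departure_time=7:30, ampm=am', this yields
#   'inform(departure_time="7:30")&inform(ampm="am")'                (delex=False)
#   'inform(departure_time="*DEPARTURE_TIME")&inform(ampm="*AMPM")'  (delex=True)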

def ask(data_line, response_text, question):
    """Show the DA and the response being edited, then prompt the annotator
    and return their answer (stripped of surrounding whitespace)."""
    printout(data_line.dat + ' ' + data_line.da + "\n", colour=RED)
    printout(response_text + "\n", colour=BLUE)
    raw_result = raw_input(question)
    return raw_result.strip()

def normalize_response(response_text, data_line):
    """Normalize response text:
    * capitalization
    * spacing
    * final punctuation
    * spelling variants
    * spellcheck
    """
    toks = ' ' + response_text + ' '  # pad with spaces for easy regexes

    # bugfix, 0:00pm -> 12:00pm
    toks = re.sub(r'\b0:([0-9][0-9]\s*[ap]m)\b', r'12:\1', toks)

    # find out data items, store them to exclude from spell checking
    data_toks = set()
    for data_item in data_line.values_list:
        data_toks.update(data_item.lower().split(' '))

    # tokenize
    toks = re.sub(r'([?.!;,:-]+)(?![0-9])', r' \1 ', toks)  # enforce space around all punct
    toks = re.sub(r'\s+', ' ', toks)

    if 'am' in data_line.values_list or 'pm' in data_line.values_list:
        toks = re.sub(r'([0-9])([apAP][mM])\b', r'\1 \2', toks)  # separate time & am/pm (acc. to DA)

    toks = toks.lower()  # work out spelling in lowercase

    # spelling fixes
    toks = re.sub(r'\bcan not\b', 'cannot', toks)

    toks = toks[1:-1].split(' ')  # remove the padding spaces and split
    toks = [SPELL_FIXES.get(tok, tok) for tok in toks]

    # spelling checks (data item tokens and custom dictionary entries are exempt)
    toks_out = []
    for tok in toks:
        if tok not in data_toks and tok not in CUSTOM_DICT and not spellchecker.spell(tok):
            resp = ask(data_line, response_text, 'Correct spelling of `%s\' -- [A]ll/[S]kip (text of post-edit): ' % tok)
            if resp.upper() == 'S':
                toks_out.append(tok)
                continue
            elif resp.upper().startswith('A '):
                # 'A <text>' remembers the fix for all further occurrences
                resp = resp[2:]
                SPELL_FIXES[tok] = resp
            toks_out.append(resp)
        else:
            toks_out.append(tok)

    toks = ' ' + ' '.join(toks_out) + ' '

    # normalize capitalization
    toks = re.sub(r'\bi\b', 'I', toks)
    toks = re.sub(r'\bok\b', 'OK', toks)
    toks = re.sub(r'([?.!]|^) ([a-z])',
                  lambda match: match.group(1) + ' ' + match.group(2).upper(),
                  toks)

    for data_item in data_line.values_list:  # capitalization of data items
        if data_item == data_item.lower():
            continue
        toks = re.sub(r'\b' + data_item + r'\b', data_item, toks, flags=re.IGNORECASE)

    toks = re.sub(r' ([?.!;,:-]+)', r'\1', toks)  # remove space before punctuation
    toks = toks[1:-1]  # remove padding spaces

    # TODO uncomment
    # # check if final punctuation matches the type
    #if re.match(r'^(confirm|reply|apologize|confirm&reply)$', data_line.dat) and toks.endswith('?'):
        #resp = ask(data_line, response_text, "Change final punctuation to `.'? [Y/N]: ")
        #if resp.upper() == 'Y':
            #toks = toks[:-1] + '.'
    #if (re.match(r'^(request|confirm&request)$', data_line.dat) and
            #(toks.endswith('.') or toks.endswith('!')) and
            #not re.match(r'^(Please (tell|provide|let me know) |I(\'m going to)? need)', toks)):
        #resp = ask(data_line, response_text, "Change final punctuation to `?'? [Y/N]: ")
        #if resp.upper() == 'Y':
            #toks = toks[:-1] + '?'

    # add final punctuation if not present
    if toks[-1] not in ['.', '!', '?']:
        if re.match(r'^(confirm|reply|apologize|confirm-reply)$', data_line.dat):
            toks += '.'
        elif re.match(r'^(request|confirm-request)$', data_line.dat):
            toks += '?'

    # TODO uncomment
    # # fix at 0:30 -> in 0:30 if needed
    #if 'departure_time_rel' in data_line.slots and re.search(r'\bat 0:[0-9][0-9]\b', toks):
        #resp = ask(data_line, response_text, "Change `at' to `in'? [Y/N]: ")
        #if resp.upper() == 'Y':
            #toks = re.sub(r'\bat (0:[0-9][0-9])\b', r'in \1', toks)

    return toks
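# e.g. normalize_response('sorry , i can not confrim that .', line)
#      -> 'Sorry, I cannot confirm that.'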


def delexicalize(response_text, data_line):
    """Delexicalize (normalized) response text."""
    text = response_text
    for slot, value in data_line.slot_values.iteritems():
        if slot in ['alternative', 'num_transfers']:
            continue
        # check if the value to be delexicalized is actually present
        if not any(re.search(r'\b' + val + r'\b', response_text, re.IGNORECASE)
                   for val in [value] + ALT_VALUES.get(value, [])):
            raise ValueError("Cannot find value `%s=%s' in response `%s'!" % (slot, value, response_text))
        # delexicalize
        for val in [value] + ALT_VALUES.get(value, []):
            text = re.sub(r'\b' + val + r'\b', '*' + slot.upper(), text, flags=re.IGNORECASE)
    return text
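# e.g. with slot_values {'departure_time': '7:00'}, the response
# "You depart at seven o'clock." becomes "You depart at *DEPARTURE_TIME."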


def main(args):

    finished = defaultdict(list)
    input_lines = []

    # load input data
    with args.input_file as fh:
        csvread = csv.reader(fh, delimiter=str(args.input_csv_delim),
                             quotechar=b'"', encoding="UTF-8")
        columns = DataLine.get_columns_from_header(csvread.next())
        for row in csvread:
            input_lines.append(DataLine.from_csv_line(row, columns))

    # load all results files provided
    for finished_file in args.finished_files:
        with finished_file as fh:
            csvread = csv.reader(fh, delimiter=str(args.finished_csv_delim),
                                 quotechar=b'"', encoding="UTF-8")
            header = csvread.next()
            columns = DataLine.get_columns_from_header(header)
            try:
                judgment_column = header.index('check_result')
            except ValueError:
                judgment_column = None
            for row in csvread:
                # treat rejected/unchecked as unfinished
                if judgment_column is None or not row[judgment_column] or row[judgment_column][0] not in ['Y', 'E']:
                    continue
                # save all accepted finished lines
                finished[DataLine.from_csv_line(row, columns).signature].append(Result(row, header))
    print >> sys.stderr, "Loaded input: %d, Loaded finished: %d" % (len(input_lines), len(finished))

    # TODO remove this; checking whether we have more answers than we need
    #for sig, res in finished.iteritems():
        #if len(res) > 3:
            #inp = [line.as_tuple() for line in input_lines if line.signature == sig]
            #print inp
            #print sig, len(res)
            #print '\n'.join([r.reply for r in res])

    out_lines = []
    out_headers = ['context_utt', 'context_freq', 'context_parse', 'response_da',
                   'response_nl1', 'response_nl2', 'response_nl3',
                   'context_utt_l', 'context_parse_l', 'response_da_l',
                   'response_nl1_l', 'response_nl2_l', 'response_nl3_l']

    for num, line in enumerate(input_lines):
        if not finished[line.signature]:
            continue
        res = finished[line.signature]

        if len(res) > 3:
            print >> sys.stderr, 'WARN: Total %d responses for the signature %s' % (len(res), line.signature)
            import pudb; pudb.set_trace()  # drop into the debugger to inspect the extra responses

        # create output line
        out_line = Result([''] * len(out_headers), out_headers)
        out_line.context_utt = line.abstr_utt
        out_line.context_utt_l = line.utt
        out_line.context_parse = convert_parse(line)
        out_line.context_parse_l = lexicalize_parse(line, out_line.context_parse)
        out_line.response_da = convert_da(line, delex=True)
        out_line.response_da_l = convert_da(line)
        out_line.context_freq = line.occ_num

        out_line.response_nl1_l = normalize_response(res[0].reply, line)
        out_line.response_nl1 = delexicalize(out_line.response_nl1_l, line)
        printout(str(num) + ' ' + out_line.response_da_l + "\n")
        printout(out_line.response_nl1_l + "\n", colour=YELLOW)

        if len(res) > 1:
            out_line.response_nl2_l = normalize_response(res[1].reply, line)
            out_line.response_nl2 = delexicalize(out_line.response_nl2_l, line)
            printout(out_line.response_nl2_l + "\n", colour=YELLOW)

        if len(res) > 2:
            out_line.response_nl3_l = normalize_response(res[2].reply, line)
            out_line.response_nl3 = delexicalize(out_line.response_nl3_l, line)
            printout(out_line.response_nl3_l + "\n", colour=YELLOW)

        out_lines.append(out_line)

    # write the CSV
    with sys.stdout as fh:
        # starting with the header
        csvwrite = csv.writer(fh, delimiter=b",", lineterminator="\n", encoding="UTF-8")
        csvwrite.writerow(out_headers)
        for out_line in out_lines:
            csvwrite.writerow(out_line)

    # TODO write JSON (note: args.output_file_name is parsed but currently unused)


if __name__ == '__main__':
    ap = ArgumentParser()
    ap.add_argument('-f', '--finished-csv-delim', type=str, default=",")
    ap.add_argument('-i', '--input-csv-delim', type=str, default="\t")
    ap.add_argument('input_file', type=FileType('r'))
    ap.add_argument('output_file_name', type=str)
    ap.add_argument('finished_files', type=FileType('r'), nargs='+')
    args = ap.parse_args()
    main(args)
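# Usage sketch (placeholder file names; the resulting CSV goes to stdout):
#   python THIS_SCRIPT.py input.tsv out_name finished1.csv finished2.csv > dataset.csv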