Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update from origin #1

Merged
merged 19 commits into from
Dec 10, 2016
9 changes: 4 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@ Code for the [Recurrent Neural Network Grammars](https://arxiv.org/abs/1602.0777
* [CMake](http://www.cmake.org/)
* [EVALB](http://nlp.cs.nyu.edu/evalb/) (latest version. IMPORTANT: please put the EVALB folder on the same directory as `get_oracle.py` and `sample_input_chinese.txt` to ensure compatibility)

cmake version 2.8+
The latest development version of Eigen
C++ compiler (supporting the C++11 language standard)
Boost libraries

# Build instructions
Assuming the latest development version of Eigen is stored at: /opt/tools/eigen-dev

Expand Down Expand Up @@ -56,6 +51,10 @@ On the English PTB dataset the discriminative model typically converges after ab

nohup build/nt-parser/nt-parser --cnn-mem 1700 -x -T [training_oracle_file] -d [dev_oracle_file] -C [original_dev_file (PTB bracketed format, see sample_input_english.txt)] -P -t --pretrained_dim [dimension of pre-trained word embedding] -w [pre-trained word embedding] --lstm_input_dim 128 --hidden_dim 128 -D 0.2 > log.txt

IMPORTANT: please run the command at the same folder where `remove_dev_unk.py` is located.

If not using pre-trained word embedding, then remove the `--pretrained_dim` and `-w` flags.

The training log is printed to `log.txt` (including information on where the parameter file for the model is saved to, which is used for decoding under the -m option below)

### Decoding with discriminative model
Expand Down
4 changes: 2 additions & 2 deletions nt-parser/nt-parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1076,7 +1076,7 @@ int main(int argc, char** argv) {
double err = (trs - right) / trs;
cerr << "Dev output in " << pfx << endl;
//parser::EvalBResults res = parser::Evaluate("foo", pfx);
std::string command="python ../remove_dev_unk.py "+ corpus.devdata +" "+pfx+" > evaluable.txt";
std::string command="python remove_dev_unk.py "+ corpus.devdata +" "+pfx+" > evaluable.txt";
const char* cmd=command.c_str();
system(cmd);

Expand Down Expand Up @@ -1208,7 +1208,7 @@ int main(int argc, char** argv) {
double err = (trs - right) / trs;
cerr << "Test output in " << pfx << endl;
//parser::EvalBResults res = parser::Evaluate("foo", pfx);
std::string command="python ../remove_dev_unk.py "+ corpus.devdata +" "+pfx+" > evaluable.txt";
std::string command="python remove_dev_unk.py "+ corpus.devdata +" "+pfx+" > evaluable.txt";
const char* cmd=command.c_str();
system(cmd);

Expand Down
66 changes: 66 additions & 0 deletions remove_dev_unk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import sys

def is_next_open_bracket(line, start_idx):
for char in line[(start_idx + 1):]:
if char == '(':
return True
elif char == ')':
return False
raise IndexError('Bracket possibly not balanced, open bracket not followed by closed bracket')

def get_between_brackets(line, start_idx):
output = []
for char in line[(start_idx + 1):]:
if char == ')':
break
assert not(char == '(')
output.append(char)
return ''.join(output)

def get_tags_tokens_lowercase(line):
output = []
#print 'curr line', line_strip
line_strip = line.rstrip()
#print 'length of the sentence', len(line_strip)
for i in range(len(line_strip)):
if i == 0:
assert line_strip[i] == '('
if line_strip[i] == '(' and not(is_next_open_bracket(line_strip, i)): # fulfilling this condition means this is a terminal symbol
output.append(get_between_brackets(line_strip, i))
#print 'output:',output
output_tags = []
output_tokens = []
output_lowercase = []
for terminal in output:
terminal_split = terminal.split()
assert len(terminal_split) == 2 # each terminal contains a POS tag and word
output_tags.append(terminal_split[0])
output_tokens.append(terminal_split[1])
output_lowercase.append(terminal_split[1].lower())
return [output_tags, output_tokens, output_lowercase]

def main():
if len(sys.argv) != 3:
raise NotImplementedError('Program only takes two arguments: the gold dev set and the output file dev set')
gold_file = open(sys.argv[1], 'r')
sys_file = open(sys.argv[2], 'r')
gold_lines = gold_file.readlines()
sys_lines = sys_file.readlines()
gold_file.close()
sys_file.close()
assert len(gold_lines) == len(sys_lines)
for gold_line, sys_line in zip(gold_lines, sys_lines):
gold_tags, gold_tokens, gold_lowercase = get_tags_tokens_lowercase(gold_line)
sys_tags, sys_tokens, sys_lowercase = get_tags_tokens_lowercase(sys_line)
assert len(gold_tokens) == len(gold_tags)
assert len(gold_tokens) == len(gold_lowercase)
assert len(gold_tokens) == len(sys_tokens)
assert len(sys_tokens) == len(sys_tags)
assert len(sys_tags) == len(sys_lowercase)
output_string = sys_line
for gold_token, gold_tag, sys_token in zip(gold_tokens, gold_tags, sys_tokens):
output_string = output_string.replace('(XX ' + sys_token + ')', '(' + gold_tag + ' ' + gold_token + ')', 1)
print output_string.rstrip()

if __name__ == '__main__':
main()