add 'words_ensemble' model
shixing committed May 15, 2017
1 parent 912c982 commit ef0a575
Showing 7 changed files with 165 additions and 21 deletions.
7 changes: 7 additions & 0 deletions .gitignore
@@ -0,0 +1,7 @@
.DS_Store
*~
[#].[#]
.[#]*
*[#]

*pyc
12 changes: 11 additions & 1 deletion README_XING.md
@@ -134,11 +134,21 @@ $EXEC -k 10 best.nn kbest_fsa.txt --print-score 1 -b 5 --fsa fsa.txt --print-bea
You can type one of the following commands into STDIN:

1. `source <source_file>` : process the source-side forward propagation.
2. `words word1 word2 word3` : feed the target-side RNN with the word sequence `word1 word2 word3`. This is supposed to be the line that the human composed.
3. `fsaline <fsa_file> encourage_list_files:enc1.txt,enc2.txt encourage_weights:1.0,-1.0 repetition:0.0 alliteration:0.0 wordlen:0.0` : let the RNN continue decoding with the FSA.

Both steps 2 and 3 will start from the previous hidden and cell states of the target-side RNN.
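
As an illustrative sketch (the file names and English words here are placeholders), a session typed on STDIN might look like the lines below; the decoder prints `[END]` when it finishes processing each command:

```
source source.valid.txt
words the old man went home
fsaline fsa.txt encourage_list_files:enc1.txt,enc2.txt encourage_weights:1.0,-1.0 repetition:0.0 alliteration:0.0 wordlen:0.0
```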

You can also ensemble two models `best.nn.1` and `best.nn.2` by:

```
$EXEC -k 10 best.nn.1 best.nn.2 kbest_fsa.txt --print-score 1 -b 5 --fsa fsa.txt --print-beam 1 --decode-main-data-files source.valid.txt source.valid.txt --interactive-line 1 --interactive 1
```

Additionally, you can use the `words_ensemble` option to provide two different human inputs for the two models:

4. `words_ensemble word11 word12 word13 ___sep___ word21 word22 word23 ___sep___` : feed the target-side RNN with the word sequence `word11 word12 word13` for `best.nn.1` and `word21 word22 word23` for `best.nn.2`. These are supposed to be the lines that the human composed.
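
As a concrete illustration (the words are placeholders), a `words_ensemble` line that feeds a different human-composed draft to each of the two models might look like:

```
words_ensemble the old man went home ___sep___ an old man walked home ___sep___
```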

# Decoding with Word Alignment

Suppose we are translating from French to English; we can use the word alignment information to speed up decoding. Please find details in 5. [Speeding up Neural Machine Translation Decoding by Shrinking Run-time Vocabulary](http://xingshi.me/data/pdf/ACL2017short.pdf).
4 changes: 2 additions & 2 deletions executable/ZOPH_RNN_XING
Git LFS file not shown
12 changes: 11 additions & 1 deletion scripts/fsa/demo.sh
@@ -41,10 +41,20 @@ $EXEC -k 10 best.nn kbest_fsa.txt --print-score 1 -b 5 --fsa fsa.txt --print-bea
# the command line should contain --fsa <fsa_file> and --decode-main-data-files <source_file>; both fsa_file and source_file must exist and be valid, although they are not actually used in interactive mode.

# [Interactive-line mode] : --interactive 1 --interactive-line 1
$EXEC -k 10 best.nn kbest_fsa.txt --print-score 1 -b 5 --fsa fsa.txt --print-beam 1 --decode-main-data-files source.valid.txt --interactive-line 1 --interactive-line 1
$EXEC -k 10 best.nn kbest_fsa.txt --print-score 1 -b 5 --fsa fsa.txt --print-beam 1 --decode-main-data-files source.valid.txt --interactive-line 1 --interactive 1

# 1. `source <source_file>` : process the source-side forward propagation.
# 2. `words word1 word2 word3` : feed the target-side RNN with the word sequence `word1 word2 word3`. This is supposed to be the line that the human composed.
# 3. `fsaline <fsa_file> encourage_list_files:enc1.txt,enc2.txt encourage_weights:1.0,-1.0 repetition:0.0 alliteration:0.0 wordlen:0.0` : let the RNN continue decoding with the FSA.


# [Interactive-line mode + ensemble ] : --interactive 1 --interactive-line 1
$EXEC -k 10 best.nn best.nn kbest_fsa.txt --print-score 1 -b 5 --fsa fsa.txt --print-beam 1 --decode-main-data-files source.valid.txt source.valid.txt --interactive-line 1 --interactive 1

# 1. `source <source_file>` : process the source-side forward propagation.
# 2. `words word1 word2 word3` : feed the target-side RNN with the word sequence `word1 word2 word3`. This is supposed to be the line that the human composed.
# 3. `words_ensemble word11 word12 word13 ___sep___ word21 word22 word23 ___sep___` : feed the target-side RNN with the word sequence `word11 word12 word13` for the first model and `word21 word22 word23` for the second model. These are supposed to be the lines that the human composed (see the example typed line below).
# 4. `fsaline <fsa_file> encourage_list_files:enc1.txt,enc2.txt encourage_weights:1.0,-1.0 repetition:0.0 alliteration:0.0 wordlen:0.0` : let the RNN continue decoding with the FSA.
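# For example (placeholder words), a words_ensemble line typed on STDIN for the two-model ensemble above could be:
#   words_ensemble the old man went home ___sep___ an old man walked home ___sep___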



6 changes: 5 additions & 1 deletion src/decoder_model_wrapper.h
@@ -15,7 +15,11 @@ class decoder_model_wrapper {
dType *h_outputdist;
dType *d_temp_swap_vals;
int *d_input_vocab_indicies_source;
int *d_current_indicies;

int *h_current_indices; // every model should have its own copy of this vector for model ensembling;



neuralMT_model<dType> *model; //This is the model

4 changes: 3 additions & 1 deletion src/decoder_model_wrapper.hpp
@@ -71,7 +71,9 @@ decoder_model_wrapper<dType>::decoder_model_wrapper(int gpu_num,int beam_size,

//allocate the current indicies
CUDA_ERROR_WRAPPER(cudaMalloc((void**)&d_current_indicies,beam_size*sizeof(int)),"GPU memory allocation failed\n");

h_current_indices = (int *) malloc(beam_size*sizeof(int));
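// h_current_indices is the host-side counterpart of d_current_indicies, one entry per beam; it lets this model keep its own copy of the current target indices when ensembling (see decoder_model_wrapper.h)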


model = new neuralMT_model<dType>();
//initialize the model
model->initModel_decoding(LSTM_size,beam_size,source_vocab_size,target_vocab_size,
141 changes: 126 additions & 15 deletions src/ensemble_factory.hpp
@@ -77,16 +77,24 @@ void ensemble_factory<dType>::decode_file_interactive_line() {
// both of these two functions need to prepare the following two things:
// 1. init pre_target_states.c_t_pre/h_t_pre as h_2 ( h2 = lstm(w2,h1) )
// 2. init h_current_indices = [w3] * beam_size;
// 3. Now that we can ensemble different models, each models[i] has its own h_current_indices; the pre- and post-conditions of each action are:
/*
{nothing} -> source -> {models[i].h_current_indices = model_decoder.h_current_indices}
{model_decoder.h_current_indices = models[i].h_current_indices} -> words -> {models[i].h_current_indices = model_decoder.h_current_indices}
{model_decoder.h_current_indices = models[i].h_current_indices} -> words_ensemble -> {models[i].h_current_indices = model_decoder.h_current_indices}
{model_decoder.h_current_indices = models[i].h_current_indices} -> fsaline -> {models[i].h_current_indices = model_decoder.h_current_indices}
*/
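// i.e., before each forward pass the shared working buffer model_decoder->h_current_indices is loaded from the model's saved copy, and once the action finishes the updated indices are written back into every model's copy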


while (true) {
// 1. source <source_file> -> [END]
// 2. words <words> -> [END]
// 3. fsa <fsa_file> encourage_list_files:enc1.txt,enc2.txt encourage_weights:1.0,-1.0 repetition:0.0 alliteration:0.0 wordlen:0.0 -> [END] : as normal
// 2. words_ensemble <words> ___sep___ <words> ___sep___ -> [END]
// 3. (removed) fsa <fsa_file> encourage_list_files:enc1.txt,enc2.txt encourage_weights:1.0,-1.0 repetition:0.0 alliteration:0.0 wordlen:0.0 -> [END] : as normal
// 4. fsaline <fsa_file> encourage_list_files:enc1.txt,enc2.txt encourage_weights:1.0,-1.0 repetition:0.0 alliteration:0.0 wordlen:0.0 -> [END] : as normal, but at the end, move the corresponding c_t and h_t to all beams.


std::cout<<"Please input <source/words/fsa/fsaline> <source_file/words/fsa_file>\n";
std::cout<<"Please input <source/words/words_ensemble/fsaline> <source_file/words/words seperated by ___sep___/fsa_file>\n";
std::cout.flush();
// read input
// input format:
@@ -137,6 +145,13 @@ void ensemble_factory<dType>::decode_file_interactive_line() {
for(int j=0; j < models.size(); j++) {
models[j].forward_prop_source();
}

//copy model_decoder->h_current_indices to each model's h_current_indices;
for(int j=0; j < models.size(); j++) {
for (int k = 0; k < model_decoder->beam_size; k += 1 ){
models[j].h_current_indices[k] = model_decoder->h_current_indices[k];
}
}

std::cout<<"[END]\n";
std::cout.flush();
@@ -164,14 +179,23 @@ void ensemble_factory<dType>::decode_file_interactive_line() {
word_indices.push_back(word_index);
}



for (int i = 0; i< word_indices.size() ; i ++){
// init model_decoder->h_current_indices with each model's h_current_indices;
for (int k = 0; k < model_decoder->beam_size; k += 1 ){
model_decoder->h_current_indices[k] = models[0].h_current_indices[k] ;
}

std::cout<< "WI: "<< model_decoder->h_current_indices[0] << "\n";
for (int i = 0; i< word_indices.size() ; i ++){

for(int j=0; j < models.size(); j++) {
if (i == 0){
// for words_ensemble, different models have different h_current_indices;
for (int k = 0; k < model_decoder->beam_size; k += 1 ){
model_decoder->h_current_indices[k] = models[j].h_current_indices[k] ;
}
}

std::cout<< "WI["<<j<<"]: "<< model_decoder->h_current_indices[0] << "\n";

models[j].forward_prop_target(curr_index+i,model_decoder->h_current_indices);
models[j].target_copy_prev_states();
}
@@ -182,16 +206,87 @@ void ensemble_factory<dType>::decode_file_interactive_line() {
model_decoder->h_current_indices[j] = word_index;
}


}

// update each model's h_current_indices with model_decoder->h_current_indices;
for(int j=0; j < models.size(); j++) {
for (int k = 0; k < model_decoder->beam_size; k += 1 ){
models[j].h_current_indices[k] = model_decoder->h_current_indices[k];
}
}

std::cout<<"[END]\n";
std::cout.flush();

right_after_encoding = false;

} else if (action == "fsa") {
} else if (action == "words_ensemble"){
std::vector<std::vector<int>> word_indices_array;
for (int i = 0; i< models.size(); i++){
std::vector<int> temp;
word_indices_array.push_back(temp);
}

int curr_index = 0;

if (right_after_encoding){
curr_index = 0;
} else {
curr_index = 1;
}

int i_sentence = 0;
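// split the typed line on ___sep___ so that word_indices_array[i_sentence] collects the word indices intended for models[i_sentence]; out-of-vocabulary words map to <UNK> (index 2)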
for (int i = 1; i < ll.size(); i +=1 ){
std::string word = ll[i];
if (word == "___sep___"){
i_sentence+=1;
continue;
}
int word_index = 2; // <UNK>
if (model_decoder->tgt_mapping.count(word) > 0){
word_index = model_decoder->tgt_mapping[word];
}
word_indices_array[i_sentence].push_back(word_index);
}

for (int j=0; j < word_indices_array.size(); j += 1){
std::vector<int> & word_indices = word_indices_array[j];

// init model_decoder->h_current_indices with each model's h_current_indices;
for (int k = 0; k < model_decoder->beam_size; k += 1 ){
model_decoder->h_current_indices[k] = models[j].h_current_indices[k] ;
}

for (int i = 0; i< word_indices.size() ; i ++){


std::cout<< "WI["<<j<<"]: "<< model_decoder->h_current_indices[0] << "\n";

models[j].forward_prop_target(curr_index+i,model_decoder->h_current_indices);
models[j].target_copy_prev_states();

int word_index = word_indices[i];

for (int j=0 ; j< model_decoder->beam_size; j++){
model_decoder->h_current_indices[j] = word_index;
}

}

// update each model's h_current_indices with model_decoder->h_current_indices;
for (int k = 0; k < model_decoder->beam_size; k += 1 ){
models[j].h_current_indices[k] = model_decoder->h_current_indices[k];
}

}


std::cout<<"[END]\n";
std::cout.flush();

right_after_encoding = false;

} /*else if (action == "fsa") {
fsa_file = ll[1];
model_decoder->init_fsa_interactive(fsa_file);
@@ -237,10 +332,8 @@ void ensemble_factory<dType>::decode_file_interactive_line() {
//process wordlen weight
model_decoder->wordlen_weight = wordlen_weight;


decode_file_line(right_after_encoding,false);
//read output and print into stdout;
input_file_prep input_helper;
input_helper.unint_file(p_params->model_names[0],p_params->decoder_output_file,p_params->decoder_final_file,false,true);
@@ -264,7 +357,7 @@ void ensemble_factory<dType>::decode_file_interactive_line() {
right_after_encoding = false;
} else if (action == "fsaline") {
} */ else if (action == "fsaline") {

fsa_file = ll[1];

@@ -343,7 +436,7 @@ void ensemble_factory<dType>::decode_file_interactive_line() {

template<typename dType>
void ensemble_factory<dType>::decode_file_line(bool right_after_encoding, bool end_transfer) {
// right_after_encoding = true, means the system is never decoding a word,
// right_after_encoding = true, means the system hasn't decoded a word,
//
bool pre_end_transfer = model_decoder->end_transfer;
model_decoder->end_transfer = end_transfer;
@@ -363,7 +456,16 @@ void ensemble_factory<dType>::decode_file_line(bool right_after_encoding, bool e
for(int j=0; j < models.size(); j++) {
// curr_index: whether it's 0 or non-0. Doesn't matter if it's 1 or 2 or 3.
// &c_t_pre = &pre_state ; c_t = f(c_t_pre)

if (curr_index == 0){
// for words_ensemble, different models have different h_current_indices;
for (int k = 0; k < model_decoder->beam_size; k += 1 ){
model_decoder->h_current_indices[k] = models[j].h_current_indices[k] ;
}
}

models[j].forward_prop_target(curr_index+start_index,model_decoder->h_current_indices);

}


Expand Down Expand Up @@ -403,6 +505,15 @@ void ensemble_factory<dType>::decode_file_line(bool right_after_encoding, bool e
model_decoder->output_k_best_hypotheses(models[0].fileh->sentence_length);
//model_decoder->print_current_hypotheses();
model_decoder->end_transfer = pre_end_transfer;

// update each model's h_current_indices with model_decoder->h_current_indices;
for(int j=0; j < models.size(); j++) {
for (int k = 0; k < model_decoder->beam_size; k += 1 ){
models[j].h_current_indices[k] = model_decoder->h_current_indices[k];
}
}


}


