Skip to content
This repository was archived by the owner on May 25, 2020. It is now read-only.

Commit 1efee48

Browse files
author
Nicolas
authored
Cakechat refactoring (#23)
CakeChat refactoring
1 parent ec68708 commit 1efee48

34 files changed

+530
-448
lines changed

cakechat/config.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import os
22

3-
from cakechat.utils.env import is_dev_env
43
from cakechat.utils.data_structures import create_namedtuple_instance
4+
from cakechat.utils.env import is_dev_env
55

6-
RANDOM_SEED = 42 # Fix the random seed to a certain value to make everything reproducable
6+
RANDOM_SEED = 42 # Fix the random seed to a certain value to make everything reproducible
77

88
# AWS S3 params
99
S3_MODELS_BUCKET_NAME = 'cake-chat-data' # S3 bucket with all the data
@@ -15,6 +15,7 @@
1515
# data params
1616
DATA_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data') # Directory to store all the data
1717
# e.g. datasets, models, indices
18+
NN_MODELS_DIR = os.path.join(DATA_DIR, 'nn_models') # Path to a directory for saving and restoring dialog models
1819
PROCESSED_CORPUS_DIR = os.path.join(DATA_DIR, 'corpora_processed') # Path to a processed corpora datasets
1920
TOKEN_INDEX_DIR = os.path.join(DATA_DIR, 'tokens_index') # Path to a prepared tokens index file
2021
CONDITION_IDS_INDEX_DIR = os.path.join(DATA_DIR, 'conditions_index') # Path to a prepared conditions index file
@@ -24,12 +25,13 @@
2425
TRAIN_CORPUS_NAME = 'train_' + BASE_CORPUS_NAME # Corpus name prefix for the training dataset
2526
CONTEXT_SENSITIVE_VAL_CORPUS_NAME = 'val_' + BASE_CORPUS_NAME # Corpus name prefix for the validation dataset
2627

27-
VAL_SUBSET_SIZE = 250 # Subset from the validation dataset to be used in validation metrics calculation
28+
MAX_VAL_LINES_NUM = 10000 # Max lines number from validation set to be used for metrics calculation
29+
VAL_SUBSET_SIZE = 250 # Subset from the validation dataset to be used to calculated some validation metrics
2830
TRAIN_SUBSET_SIZE = int(os.environ['SLICE_TRAINSET']) if 'SLICE_TRAINSET' in os.environ else None # Subset from the
2931
# training dataset to be used during the training. In case of None use all lines in the train dataset (default behavior)
3032

3133
# test data paths
32-
TEST_DATA_DIR = os.path.join(DATA_DIR, 'quality') # Path to datasets for quality metrics calculation
34+
TEST_DATA_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data', 'quality')
3335
CONTEXT_FREE_VAL_CORPUS_NAME = 'context_free_validation_set' # Context-free validation set path
3436
TEST_CORPUS_NAME = 'context_free_test_set' # Context-free test set path
3537
QUESTIONS_CORPUS_NAME = 'context_free_questions' # Context-free questions only path
@@ -61,9 +63,12 @@
6163
OUTPUT_SEQUENCE_LENGTH = 32 # Output sequence length. Better to keep as INPUT_SEQUENCE_LENGTH+2 for start/end tokens
6264
BATCH_SIZE = 192 # Default batch size which fits into 8GB of GPU memory
6365
SHUFFLE_TRAINING_BATCHES = True # Shuffle training batches in the dataset each epoch
64-
EPOCHES_NUM = 100 # Total epochs num
66+
EPOCHS_NUM = 100 # Total epochs num
6567
GRAD_CLIP = 5.0 # Gradient clipping passed into theano.gradient.grad_clip()
66-
ADADELTA_LEARNING_RATE = 1.0 # Initial AdaDelta learning rate
68+
LEARNING_RATE = 1.0 # Learning rate for the chosen optimizer (currently using Adadelta, see model.py)
69+
70+
# model params
71+
NN_MODEL_PREFIX = 'cakechat' # Specify prefix to be prepended to model's name
6772

6873
# predictions params
6974
MAX_PREDICTIONS_LENGTH = 40 # Max. number of tokens which can be generated on the prediction step
@@ -94,8 +99,10 @@
9499
LOG_CANDIDATES_NUM = 10 # Number of candidates to be printed to output during the logging
95100
SCREEN_LOG_NUM_TEST_LINES = 10 # Number of first test lines to use when logging outputs on screen
96101
SCREEN_LOG_FREQUENCY_PER_BATCHES = 500 # How many batches to train until next logging of output on screen
97-
LOG_FREQUENCY_PER_BATCHES = 2500 # How many batches to train until next logging of all the output into file
98-
LOG_LOSS_DECAY = 0.99 # Decay for the averaging the loss which is printed in logs
102+
LOG_TO_TB_FREQUENCY_PER_BATCHES = 500 # How many batches to train until next metrics computed for TensorBoard
103+
LOG_TO_FILE_FREQUENCY_PER_BATCHES = 2500 # How many batches to train until next logging of all the output into file
104+
SAVE_MODEL_FREQUENCY_PER_BATCHES = 2500 # How many batches to train until next logging of all the output into file
105+
AVG_LOSS_DECAY = 0.99 # Decay for the averaging the loss
99106

100107
# Use reduced sizes for input/output sequences, hidden layers and datasets sizes for the 'Developer Mode'
101108
if is_dev_env():
@@ -105,10 +112,13 @@
105112
BATCH_SIZE = 128
106113
HIDDEN_LAYER_DIMENSION = 7
107114
SCREEN_LOG_FREQUENCY_PER_BATCHES = 2
108-
LOG_FREQUENCY_PER_BATCHES = 3
115+
LOG_TO_TB_FREQUENCY_PER_BATCHES = 3
116+
LOG_TO_FILE_FREQUENCY_PER_BATCHES = 4
117+
SAVE_MODEL_FREQUENCY_PER_BATCHES = 4
109118
WORD_EMBEDDING_DIMENSION = 15
110119
SAMPLES_NUM_FOR_RERANKING = BEAM_SIZE = 5
111120
LOG_CANDIDATES_NUM = 3
112121
USE_PRETRAINED_W2V_EMBEDDINGS_LAYER = False
113122
VAL_SUBSET_SIZE = 100
123+
MAX_VAL_LINES_NUM = 100
114124
TRAIN_SUBSET_SIZE = 10000

cakechat/dialog_model/factory.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import os
22

3+
from cachetools import cached
4+
35
from cakechat.config import BASE_CORPUS_NAME, S3_MODELS_BUCKET_NAME, S3_TOKENS_IDX_REMOTE_DIR, \
46
S3_NN_MODEL_REMOTE_DIR, S3_CONDITIONS_IDX_REMOTE_DIR
57
from cakechat.dialog_model.model import get_nn_model
68
from cakechat.utils.s3 import S3FileResolver
9+
from cakechat.utils.files_utils import FileNotFoundException
710
from cakechat.utils.text_processing import get_index_to_token_path, load_index_to_item, get_index_to_condition_path
811

912

@@ -12,10 +15,10 @@ def _get_index_to_token(fetch_from_s3):
1215
if fetch_from_s3:
1316
tokens_idx_resolver = S3FileResolver(index_to_token_path, S3_MODELS_BUCKET_NAME, S3_TOKENS_IDX_REMOTE_DIR)
1417
if not tokens_idx_resolver.resolve():
15-
raise Exception('Can\'t get index_to_token because file does not exist at S3')
18+
raise FileNotFoundException('Can\'t get index_to_token because file does not exist at S3')
1619
else:
1720
if not os.path.exists(index_to_token_path):
18-
raise Exception('Can\'t get index_to_token because file does not exist. '
21+
raise FileNotFoundException('Can\'t get index_to_token because file does not exist. '
1922
'Run tools/download_model.py first to get all required files or construct it by yourself.')
2023

2124
return load_index_to_item(index_to_token_path)
@@ -27,30 +30,28 @@ def _get_index_to_condition(fetch_from_s3):
2730
index_to_condition_resolver = S3FileResolver(index_to_condition_path, S3_MODELS_BUCKET_NAME,
2831
S3_CONDITIONS_IDX_REMOTE_DIR)
2932
if not index_to_condition_resolver.resolve():
30-
raise Exception('Can\'t get index_to_condition because file does not exist at S3')
33+
raise FileNotFoundException('Can\'t get index_to_condition because file does not exist at S3')
3134
else:
3235
if not os.path.exists(index_to_condition_path):
33-
raise Exception('Can\'t get index_to_condition because file does not exist. '
36+
raise FileNotFoundException('Can\'t get index_to_condition because file does not exist. '
3437
'Run tools/download_model.py first to get all required files or construct it by yourself.')
3538

3639
return load_index_to_item(index_to_condition_path)
3740

3841

42+
@cached(cache={})
3943
def get_trained_model(reverse=False, fetch_from_s3=True):
4044
if fetch_from_s3:
4145
resolver_factory = S3FileResolver.init_resolver(
4246
bucket_name=S3_MODELS_BUCKET_NAME, remote_dir=S3_NN_MODEL_REMOTE_DIR)
4347
else:
4448
resolver_factory = None
4549

46-
nn_model, model_exists = get_nn_model(
47-
_get_index_to_token(fetch_from_s3),
48-
_get_index_to_condition(fetch_from_s3),
49-
resolver_factory=resolver_factory,
50-
is_reverse_model=reverse)
51-
50+
nn_model, model_exists = get_nn_model(index_to_token=_get_index_to_token(fetch_from_s3),
51+
index_to_condition=_get_index_to_condition(fetch_from_s3),
52+
resolver_factory=resolver_factory,
53+
is_reverse_model=reverse)
5254
if not model_exists:
53-
raise Exception('Can\'t get the model. '
54-
'Run tools/download_model.py first to get all required files or train it by yourself.')
55-
55+
raise FileNotFoundException('Can\'t get the pre-trained model. Run tools/download_model.py first '
56+
'to get all required files or train it by yourself.')
5657
return nn_model

cakechat/dialog_model/inference/candidates/beamsearch.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
from six.moves import zip_longest
2-
31
import numpy as np
4-
from six.moves import xrange
2+
from six.moves import xrange, zip_longest
53
import theano
64

75
from cakechat.dialog_model.inference.candidates.abstract_generator import AbstractCandidatesGenerator
@@ -87,7 +85,7 @@ def _update_next_candidates_and_hidden_states(self, token_idx, best_non_finished
8785
# We need to get which original candidate this token in the expanded beam corresponds to.
8886
# (to fill in all the previous tokens from self._cur_candidates)
8987
# Because all the candidates in the expanded beam were filled sequentially, we just use this formula:
90-
original_candidate_idx = candidate_idx / self._beam_size
88+
original_candidate_idx = candidate_idx // self._beam_size
9189

9290
# Construct the candidates for the next step using self._cur_candidates and the last token:
9391

@@ -123,7 +121,7 @@ def _update_finished_candidates(self, token_idx, best_finished_candidates_indice
123121
# to get all the other tokens we need to get which original candidate this token in the expanded beam
124122
# corresponds to. Because all the candidates in the expanded beam were filled sequentially, we can just
125123
# use this formula:
126-
original_candidate_idx = candidate_idx / self._beam_size
124+
original_candidate_idx = candidate_idx // self._beam_size
127125

128126
# Construct the candidates for the next step using self._cur_candidates and the last token:
129127

cakechat/dialog_model/inference/factory.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,6 @@
55
from cakechat.dialog_model.inference.reranking import DummyReranker, MMIReranker
66

77

8-
def _get_reverse_model():
9-
if not hasattr(_get_reverse_model, 'reverse_model'):
10-
try:
11-
reverse_model = get_trained_model(reverse=True)
12-
except:
13-
raise ValueError('Can\'t get reverse nn model for prediction. '
14-
'Try to run \'python tools/train.py --reverse\' or switch prediction mode to sampling.')
15-
_get_reverse_model.reverse_model = reverse_model
16-
return _get_reverse_model.reverse_model
17-
18-
198
def predictor_factory(nn_model, mode, config):
209
"""
2110
@@ -39,7 +28,7 @@ def predictor_factory(nn_model, mode, config):
3928
if config['mmi_reverse_model_score_weight'] <= 0:
4029
raise ValueError('mmi_reverse_model_score_weight should be > 0 for reranking mode')
4130

42-
reverse_model = _get_reverse_model()
31+
reverse_model = get_trained_model(reverse=True)
4332
reranker = MMIReranker(nn_model, reverse_model, config['mmi_reverse_model_score_weight'],
4433
config['repetition_penalization_coefficient'])
4534
else:

cakechat/dialog_model/inference/predict.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
import numpy as np
2-
from six.moves import xrange
3-
41
from cakechat.config import MAX_PREDICTIONS_LENGTH, BEAM_SIZE, MMI_REVERSE_MODEL_SCORE_WEIGHT, DEFAULT_TEMPERATURE, \
52
SAMPLES_NUM_FOR_RERANKING, PREDICTION_MODES, REPETITION_PENALIZE_COEFFICIENT
63
from cakechat.dialog_model.inference.factory import predictor_factory
@@ -54,7 +51,7 @@ def get_nn_response_ids(context_token_ids,
5451
"""
5552
Predicts several responses for every context.
5653
57-
:param context_token_ids: np.array; shape=(batch_size x context_size x context_len); dtype=int
54+
:param context_token_ids: np.array; shape (batch_size, context_size, context_len); dtype=int
5855
Represents all tokens ids to use for predicting
5956
:param nn_model: CakeChatModel
6057
:param mode: one of PREDICTION_MODES mode
@@ -65,7 +62,7 @@ def get_nn_response_ids(context_token_ids,
6562
:param output_seq_len: Number of tokens to generate.
6663
:param kwargs: Other prediction parameters, passed into predictor constructor.
6764
Might be different depending on mode. See PredictionConfig for the details.
68-
:return: np.array; shape=(responses_num x output_candidates_num x output_seq_len); dtype=int
65+
:return: np.array; shape (batch_size, output_candidates_num, output_seq_len); dtype=int
6966
Generated predictions.
7067
"""
7168
if mode == PREDICTION_MODES.sampling:
@@ -75,8 +72,9 @@ def get_nn_response_ids(context_token_ids,
7572
_logger.debug('Generating predicted response for the following params: %s' % prediction_config)
7673

7774
predictor = predictor_factory(nn_model, mode, prediction_config.get_options_dict())
78-
return np.array(
79-
predictor.predict_responses(context_token_ids, output_seq_len, condition_ids, output_candidates_num))
75+
responses = predictor.predict_responses(context_token_ids, output_seq_len, condition_ids, output_candidates_num)
76+
77+
return responses
8078

8179

8280
def get_nn_responses(context_token_ids,
@@ -87,19 +85,25 @@ def get_nn_responses(context_token_ids,
8785
condition_ids=None,
8886
**kwargs):
8987
"""
90-
Predicts several responses for every context and returns them as proccessed strings.
88+
Predicts output_candidates_num responses for every context and returns them in form of strings.
9189
See get_nn_response_ids for the details.
9290
93-
:return: list of lists of strings
94-
Generated predictions.
91+
:param context_token_ids: numpy array of integers, shape (contexts_num, INPUT_CONTEXT_SIZE, INPUT_SEQUENCE_LENGTH)
92+
:param nn_model: trained model
93+
:param mode: prediction mode, see const PREDICTION_MODES
94+
:param output_candidates_num: number of responses to be generated for each context
95+
:param output_seq_len: max length of generated responses
96+
:param condition_ids: extra info to be taken into account while generating response (emotion, for example)
97+
98+
:return: list of lists of strings, shape (contexts_num, output_candidates_num)
9599
"""
96-
response_tokens_ids = get_nn_response_ids(context_token_ids, nn_model, mode, output_candidates_num, output_seq_len,
97-
condition_ids, **kwargs)
98-
# Reshape to get list of lines to supply into transform_token_ids_to_sentences
99-
response_tokens_ids = np.reshape(response_tokens_ids, (-1, output_seq_len))
100-
response_tokens = transform_token_ids_to_sentences(response_tokens_ids, nn_model.index_to_token)
101-
102-
lines_num = len(response_tokens) // output_candidates_num
103-
responses = [response_tokens[i * output_candidates_num:(i + 1) * output_candidates_num] for i in xrange(lines_num)]
100+
101+
response_tokens_ids = get_nn_response_ids(context_token_ids, nn_model, mode, output_candidates_num,
102+
output_seq_len, condition_ids, **kwargs)
103+
# shape (contexts_num, output_candidates_num, output_seq_len), numpy array of integers
104+
105+
responses = [transform_token_ids_to_sentences(response_candidates_tokens_ids, nn_model.index_to_token)
106+
for response_candidates_tokens_ids in response_tokens_ids]
107+
# responses shape (contexts_num, output_candidates_num), list of lists of strings
104108

105109
return responses

cakechat/dialog_model/inference/predictor.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ def _select_best_candidates(reranked_candidates, candidates_num):
1515
If for some context we generated less then candidates_num candidates, we fill this responses with pads.
1616
"""
1717
batch_size = len(reranked_candidates)
18-
# reranked_candidates is list of lists (we need too keep it this way because we can have different number
19-
# of candidates for each context), so we can't just write rerankied_candidates.shape[2]
18+
# reranked_candidates is a list of lists (we need too keep it this way because we can have different number
19+
# of candidates for each context), so we can't just write reranked_candidates.shape[2]
2020
output_seq_len = reranked_candidates[0][0].size
21-
result = np.zeros((batch_size, candidates_num, output_seq_len))
22-
# Loop here instead of slices because number of candidates for each context can vary here
21+
result = np.zeros((batch_size, candidates_num, output_seq_len), dtype=np.int32)
22+
# Loop here instead of slices because number of candidates for each context may vary here
2323
for i in xrange(batch_size):
2424
for j, candidate in enumerate(reranked_candidates[i]):
2525
if j >= candidates_num:
@@ -30,4 +30,5 @@ def _select_best_candidates(reranked_candidates, candidates_num):
3030
def predict_responses(self, context_token_ids, output_seq_len, condition_ids=None, candidates_num=1):
3131
all_candidates = self._generator.generate_candidates(context_token_ids, condition_ids, output_seq_len)
3232
reranked_candidates = self._reranker.rerank_candidates(context_token_ids, all_candidates, condition_ids)
33-
return self._select_best_candidates(reranked_candidates, candidates_num)
33+
selected_responses = self._select_best_candidates(reranked_candidates, candidates_num)
34+
return selected_responses

cakechat/dialog_model/inference/reranking.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
from abc import ABCMeta, abstractmethod
2-
from six.moves import zip_longest
32

43
import numpy as np
5-
from six.moves import xrange
4+
from six.moves import xrange, zip_longest
65

76
from cakechat.dialog_model.inference.service_tokens import ServiceTokensIDs
87
from cakechat.dialog_model.inference.utils import get_sequence_score_by_thought_vector, get_sequence_score, \
98
get_thought_vectors
109
from cakechat.dialog_model.model_utils import reverse_nn_input
11-
from cakechat.utils.dataset_loader import Dataset
10+
from cakechat.utils.data_types import Dataset
1211
from cakechat.utils.logger import get_logger
1312
from cakechat.utils.profile import timer
1413

@@ -51,21 +50,18 @@ def __init__(self, nn_model, reverse_model, mmi_reverse_model_score_weight, repe
5150
self._service_tokens_ids = ServiceTokensIDs(nn_model.token_to_index)
5251
self._log_repetition_penalization_coefficient = np.log(repetition_penalization_coefficient)
5352

54-
@timer
5553
def _compute_likelihood_of_output_given_input(self, thought_vector, candidates, condition_id):
5654
# Repeat to get same thought vector for each candidate
5755
thoughts_batch = np.repeat(thought_vector, candidates.shape[0], axis=0)
5856
return get_sequence_score_by_thought_vector(self._nn_model, thoughts_batch, candidates, condition_id)
5957

60-
@timer
6158
def _compute_likelihood_of_input_given_output(self, context, candidates, condition_id):
6259
# Repeat to get same context for each candidate
6360
repeated_context = np.repeat(context, candidates.shape[0], axis=0)
6461
reversed_dataset = reverse_nn_input(
6562
Dataset(x=repeated_context, y=candidates, condition_ids=None), self._service_tokens_ids)
6663
return get_sequence_score(self._reverse_model, reversed_dataset.x, reversed_dataset.y, condition_id)
6764

68-
@timer
6965
def _compute_num_repetitions(self, candidates):
7066
skip_tokens_ids = \
7167
self._service_tokens_ids.special_tokens_ids + self._service_tokens_ids.non_penalizable_tokens_ids
@@ -76,9 +72,7 @@ def _compute_num_repetitions(self, candidates):
7672
result.append(num_repetitions)
7773
return np.array(result)
7874

79-
@timer
8075
def _compute_candidates_scores(self, context, candidates, condition_id):
81-
_logger.info('Reranking {} candidates...'.format(candidates.shape[0]))
8276
context = context[np.newaxis, :] # from (seq_len,) to (1 x seq_len)
8377
thought_vector = get_thought_vectors(self._nn_model, context)
8478

cakechat/dialog_model/inference/tests/predict.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import sys
33
import unittest
4+
45
import numpy as np
56
from six.moves import xrange
67

0 commit comments

Comments
 (0)