Changes from all commits (22 commits)
299c1bc
workaround for train_set batching during inference time
varisd Jan 9, 2019
7a62312
added batching schemes from tensor2tensor
varisd Jan 30, 2019
a97affc
fixing failed travis tests
varisd Jan 30, 2019
619294f
fixing mypy and pylint errors + noise preprocessor refactor
varisd Oct 15, 2018
fc89074
morphing test files, fixing encountered bugs
jindrahelcl Sep 8, 2018
9b54b01
fixing some bugs, tests, and linters
jindrahelcl Sep 8, 2018
34eaf49
fixed options for bucketed batching; added docstring
varisd Apr 13, 2018
5da678b
implemented token_level batch_size
varisd Apr 18, 2018
ff98b28
bucketed token-level batching tested
varisd Sep 11, 2018
298c432
dataset refactored into three modules
jindrahelcl Jun 4, 2018
16d0a6c
implementation of EWC, gradient_runner and gradient averaging script
varisd Jul 25, 2018
5bcf362
fixed tests
varisd Jul 26, 2018
5dec44f
addressing PR reviews
varisd Jul 31, 2018
872e029
generic_trainer always adds l2 values to summaries
varisd Aug 1, 2018
1c8a963
addressing PR reviews + fixed variable fetching in EWCRegularizer
varisd Aug 7, 2018
f848ecf
fixed pylints in generic_trainer, fixed typos
varisd Aug 8, 2018
19f6971
removed predefined L1, L2 regularizers
varisd Aug 9, 2018
74ae170
removed squaring of gradients in EWCRegularizer
varisd Aug 17, 2018
3c8302e
added script to compute Empirical Fisher
varisd Aug 17, 2018
dbe24ec
naacl19 EWC branch cleanup
varisd Jan 30, 2019
c676d9b
rebased EWC to master
varisd Jan 30, 2019
d826ad7
bugfix in DelayedUpdateTrainer
varisd Jan 31, 2019
88 changes: 81 additions & 7 deletions neuralmonkey/dataset.py
@@ -95,6 +95,84 @@ def __init__(self,
# pylint: enable=too-few-public-methods


def _bucket_boundaries(max_length, min_length=8, length_bucket_step=1.1):
"""Create a default set of length-bucket boundaries."""
assert length_bucket_step > 1.0
x = min_length
boundaries = []
while x < max_length:
boundaries.append(x)
x = max(x + 1, int(x * length_bucket_step))
return boundaries


def get_batching_scheme(batch_size: int,
max_length: int = None,
min_length_bucket: int = 8,
length_bucket_step: float = 1.1,
shard_multiplier: int = 1,
length_multiplier: int = 1,
min_length: int = 0) -> BatchingScheme:
"""Create a batching scheme based on model hyperparameters.

Every batch contains a number of sequences divisible by `shard_multiplier`.

Args:
batch_size: int, total number of tokens in a batch.
max_length: int, sequences longer than this will be skipped. Defaults
to batch_size.
min_length_bucket: int
length_bucket_step: float greater than 1.0
shard_multiplier: an integer increasing the batch_size to suit
splitting across datashards.
length_multiplier: an integer multiplier that is used to increase the
batch sizes and sequence length tolerance.
min_length: int, sequences shorter than this will be skipped.
Returns:
A `BatchingScheme` instance with:
* bucket_boundaries: list of bucket boundaries
* bucket_batch_sizes: list of batch sizes for each length bucket
Raises:
ValueError: If min_length > max_length
"""
max_length = max_length or batch_size
if max_length < min_length:
raise ValueError("max_length must be greater or equal to min_length")

boundaries = _bucket_boundaries(max_length, min_length_bucket,
length_bucket_step)
boundaries = [boundary * length_multiplier for boundary in boundaries]
max_length *= length_multiplier

batch_sizes = [
max(1, batch_size // length) for length in boundaries + [max_length]
]
max_batch_size = max(batch_sizes)
# Since the Datasets API only allows a single constant for window_size,
# and it needs divide all bucket_batch_sizes, we pick a highly-composite
# window size and then round down all batch sizes to divisors of that
# window size, so that a window can always be divided evenly into batches.
highly_composite_numbers = [
1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260,
1680, 2520, 5040, 7560, 10080, 15120, 20160, 25200, 27720, 45360,
50400, 55440, 83160, 110880, 166320, 221760, 277200, 332640, 498960,
554400, 665280, 720720, 1081080, 1441440, 2162160, 2882880, 3603600,
4324320, 6486480, 7207200, 8648640, 10810800, 14414400, 17297280,
21621600, 32432400, 36756720, 43243200, 61261200, 73513440, 110270160
]
window_size = max(
[i for i in highly_composite_numbers if i <= 3 * max_batch_size])
divisors = [i for i in range(1, window_size + 1) if window_size % i == 0]
batch_sizes = [max([d for d in divisors if d <= bs]) for bs in batch_sizes]
window_size *= shard_multiplier
batch_sizes = [bs * shard_multiplier for bs in batch_sizes]

ret = BatchingScheme(bucket_boundaries=boundaries,
bucket_batch_sizes=batch_sizes)
return ret


# The protected functions below are designed to convert the ambiguous spec
# structures to a normalized form.

@@ -311,8 +389,8 @@ def itergen():
for s_name, (preprocessor, source) in prep_sl.items():
if source not in iterators:
raise ValueError(
"Source series for series-level preprocessor nonexistent: "
"Preprocessed series '{}', source series '{}'")
"Source series {} for series-level preprocessor nonexistent: "
"Preprocessed series '', source series ''".format(source))
iterators[s_name] = _make_sl_iterator(source, preprocessor)

# Finally, dataset-level preprocessors.
@@ -365,8 +443,6 @@ def __init__(self,
Arguments:
name: The name for the dataset.
iterators: A series-iterator generator mapping.
lazy: If False, load the data from iterators to a list and store
the list in memory.
buffer_size: Use this tuple as a minimum and maximum buffer size
for pre-loading data. This should be (a few times) larger than
the batch size used for mini-batching. When the buffer size
@@ -562,9 +638,7 @@ def itergen():
buf.append(item)

if self.shuffled:
lbuf = list(buf)
random.shuffle(lbuf)
buf = deque(lbuf)
random.shuffle(buf) # type: ignore

if not self.batching.drop_remainder:
for bucket in buckets:
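For orientation, here is a minimal usage sketch of the token-level batching scheme added above. The concrete numbers (4096 tokens per batch, maximum length 256) are illustrative, and it is assumed that `BatchingScheme` exposes the `bucket_boundaries` and `bucket_batch_sizes` fields it is constructed with in this diff:

from neuralmonkey.dataset import get_batching_scheme

# Roughly 4096 tokens per batch; sentences longer than 256 tokens are skipped.
scheme = get_batching_scheme(batch_size=4096, max_length=256)

# Shorter buckets hold more sentences per batch and longer buckets fewer,
# so the number of tokens per batch stays roughly constant.
for boundary, size in zip(scheme.bucket_boundaries, scheme.bucket_batch_sizes):
    print("sentences up to {:3d} tokens -> {:3d} per batch".format(boundary, size))
# (bucket_batch_sizes has one extra entry for the final bucket up to max_length.)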
5 changes: 4 additions & 1 deletion neuralmonkey/learning_utils.py
@@ -13,7 +13,7 @@
from termcolor import colored

from neuralmonkey.logging import log, log_print, warn
from neuralmonkey.dataset import Dataset
from neuralmonkey.dataset import Dataset, BatchingScheme
from neuralmonkey.tf_manager import TensorFlowManager
from neuralmonkey.runners.base_runner import (
BaseRunner, ExecutionResult, GraphExecutor, OutputSeries)
@@ -85,6 +85,9 @@ def training_loop(cfg: Namespace) -> None:
trainer_result = cfg.tf_manager.execute(
batch, feedables, cfg.trainers, train=True,
summaries=True)
# workaround: we need to use validation batching scheme
# during evaluation
batch.batching = BatchingScheme(batch_size=cfg.batch_size)
train_results, train_outputs, f_batch = run_on_dataset(
cfg.tf_manager, cfg.runners, cfg.dataset_runner, batch,
cfg.postprocess, write_out=False)
9 changes: 1 addition & 8 deletions neuralmonkey/model/parameterized.py
@@ -5,7 +5,7 @@
import tensorflow as tf

from neuralmonkey.tf_utils import update_initializers
from neuralmonkey.logging import log, warn
from neuralmonkey.logging import log

# pylint: enable=invalid-name
InitializerSpecs = List[Tuple[str, Callable]]
@@ -49,13 +49,6 @@ def __init__(self,
self._reuse = reuse is not None

if reuse is not None:
# pylint: disable=unidiomatic-typecheck
# Here we need an exact match of types
if type(self) != type(reuse):
warn("Warning: sharing parameters between model parts of "
"different types.")
# pylint: enable=unidiomatic-typecheck

if initializers is not None:
raise ValueError("Cannot use initializers in model part '{}' "
"that reuses variables from '{}'."
2 changes: 1 addition & 1 deletion neuralmonkey/readers/string_vector_reader.py
@@ -13,7 +13,7 @@ def process_line(line: str, lineno: int, path: str) -> np.ndarray:

return np.array(numbers, dtype=dtype)

def reader(files: List[str])-> Iterable[List[np.ndarray]]:
def reader(files: List[str]) -> Iterable[List[np.ndarray]]:
for path in files:
current_line = 0

13 changes: 12 additions & 1 deletion neuralmonkey/run.py
@@ -29,6 +29,14 @@ def main() -> None:
help="the configuration file of the experiment")
parser.add_argument("datasets", metavar="INI-TEST-DATASETS",
help="the configuration of the test datasets")
parser.add_argument("-s", "--set", type=str, metavar="SETTING",
action="append", dest="config_changes", default=[],
help="override an option in the configuration; the "
"syntax is [section.]option=value")
parser.add_argument("-v", "--var", type=str, metavar="VAR", default=[],
action="append", dest="config_vars",
help="set a variable in the configuration; the syntax "
"is var=value (shorthand for -s vars.var=value)")
parser.add_argument("--json", type=str, help="write the evaluation "
"results to this file in JSON format")
parser.add_argument("-g", "--grid", dest="grid", action="store_true",
@@ -37,7 +45,10 @@

datasets_model = load_runtime_config(args.datasets)

exp = Experiment(config_path=args.config)
args.config_changes.extend("vars.{}".format(s) for s in args.config_vars)

exp = Experiment(config_path=args.config,
config_changes=args.config_changes)
exp.build_model()
exp.load_variables(datasets_model.variables)

58 changes: 58 additions & 0 deletions neuralmonkey/runners/gradient_runner.py
@@ -0,0 +1,58 @@
from typing import Dict, List, Union

import tensorflow as tf
from typeguard import check_argument_types

from neuralmonkey.runners.base_runner import BaseRunner
from neuralmonkey.model.model_part import GenericModelPart
from neuralmonkey.decoders.autoregressive import AutoregressiveDecoder
from neuralmonkey.decoders.classifier import Classifier
from neuralmonkey.trainers.generic_trainer import GenericTrainer

# pylint: disable=invalid-name
SupportedDecoder = Union[AutoregressiveDecoder, Classifier]
# pylint: enable=invalid-name


class GradientRunner(BaseRunner[GenericModelPart]):
"""Runner for fetching gradients computed over the dataset.

The gradient runner applies the provided trainer to a dataset and
computes gradients over the gold data. It is currently used to gather
gradients for Elastic Weight Consolidation
(https://arxiv.org/pdf/1612.00796.pdf).
"""

# pylint: disable=too-few-public-methods
class Executable(BaseRunner.Executable["GradientRunner"]):

def collect_results(self, results: List[Dict]) -> None:
assert len(results) == 1

for sess_result in results:
gradient_dict = {}
tensor_names = [
t.name for t in self.executor.fetches()["gradients"]]
for name, val in zip(tensor_names, sess_result["gradients"]):
gradient_dict[name] = val

self.set_runner_result(outputs=gradient_dict, losses=[])
# pylint: enable=too-few-public-methods

def __init__(self,
output_series: str,
decoder: SupportedDecoder,
trainer: GenericTrainer) -> None:
check_argument_types()
BaseRunner[GenericModelPart].__init__(
self, output_series, decoder)

self._gradients = trainer.gradients

def fetches(self) -> Dict[str, tf.Tensor]:
return {"gradients": [g[1] for g in self._gradients]}

@property
def loss_names(self) -> List[str]:
return []
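The commit log mentions a script for computing the Empirical Fisher, but that script is not part of this diff. As a hedged, standalone sketch (the squaring and averaging convention is an assumption, not taken from the PR), the per-variable gradient dictionaries emitted by `GradientRunner` could be aggregated into a diagonal empirical Fisher estimate for EWC like this:

import numpy as np
from typing import Dict, List

def empirical_fisher(
        gradient_batches: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
    """Average squared per-variable gradients over all collected batches."""
    totals: Dict[str, np.ndarray] = {}
    for grads in gradient_batches:
        for name, grad in grads.items():
            squared = np.square(grad)  # diagonal Fisher uses the mean of g**2
            totals[name] = totals.get(name, 0.0) + squared
    return {name: total / len(gradient_batches)
            for name, total in totals.items()}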
20 changes: 16 additions & 4 deletions neuralmonkey/trainers/cross_entropy_trainer.py
@@ -7,6 +7,8 @@
from neuralmonkey.trainers.generic_trainer import GenericTrainer
from neuralmonkey.trainers.objective import (
Objective, CostObjective, ObjectiveWeight)
from neuralmonkey.trainers.regularizers import (
Regularizer, L1Regularizer, L2Regularizer)


# for compatibility reasons
@@ -23,17 +25,28 @@ class CrossEntropyTrainer(GenericTrainer):
def __init__(self,
decoders: List[Any],
decoder_weights: List[ObjectiveWeight] = None,
l1_weight: float = 0.,
l2_weight: float = 0.,
Contributor review comment: I would keep the L1 and L2 here as syntactic sugar and have the constructor build the regularizers and add them to the list.

clip_norm: float = None,
optimizer: tf.train.Optimizer = None,
regularizers: List[Regularizer] = None,
l1_weight: float = 0.,
l2_weight: float = 0.,
var_scopes: List[str] = None,
var_collection: str = None) -> None:
check_argument_types()

if decoder_weights is None:
decoder_weights = [None for _ in decoders]

if regularizers is None:
regularizers = []

if l1_weight > 0.:
regularizers.append(
L1Regularizer(name="train_l1", weight=l1_weight))
if l2_weight > 0.:
regularizers.append(
L2Regularizer(name="train_l2", weight=l2_weight))

if len(decoder_weights) != len(decoders):
raise ValueError(
"decoder_weights (length {}) do not match decoders (length {})"
Expand All @@ -45,9 +58,8 @@ def __init__(self,
GenericTrainer.__init__(
self,
objectives=objectives,
l1_weight=l1_weight,
l2_weight=l2_weight,
clip_norm=clip_norm,
optimizer=optimizer,
regularizers=regularizers,
var_scopes=var_scopes,
var_collection=var_collection)
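To make the change (and the review comment above) concrete: the legacy `l1_weight`/`l2_weight` arguments remain as syntactic sugar. This standalone sketch (the weight values are illustrative) reproduces what the constructor now does with them:

from neuralmonkey.trainers.regularizers import L1Regularizer, L2Regularizer

l1_weight, l2_weight = 0.0, 1e-5
regularizers = []
if l1_weight > 0.:
    regularizers.append(L1Regularizer(name="train_l1", weight=l1_weight))
if l2_weight > 0.:
    regularizers.append(L2Regularizer(name="train_l2", weight=l2_weight))
# The resulting list is forwarded to GenericTrainer via the new
# `regularizers` argument, replacing the removed l1_weight/l2_weight ones.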