Aggregate logs in evaluate #483

Closed · wants to merge 6 commits
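
This PR adds an `aggregate` flag to `evaluate()` on all three backend approximators (JAX, TensorFlow, PyTorch). With `aggregate=True` (the default), the per-batch metric logs are summed across batches and divided by the number of executed steps, so `evaluate()` returns metrics averaged over the whole evaluation set instead of the logs of the last batch only. A usage sketch, where `approximator` and `val_data` are placeholder names rather than anything from this diff:

# Assumed: `approximator` is a compiled BayesFlow approximator and
# `val_data` is a validation dataset; both names are illustrative.
mean_metrics = approximator.evaluate(val_data, return_dict=True, aggregate=True)
last_batch_metrics = approximator.evaluate(val_data, return_dict=True, aggregate=False)
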
135 changes: 135 additions & 0 deletions bayesflow/approximators/backend_approximators/jax_approximator.py
@@ -3,13 +3,148 @@

from bayesflow.utils import filter_kwargs

from keras.src.backend.jax.trainer import JAXEpochIterator
from keras.src import callbacks as callbacks_module


class JAXApproximator(keras.Model):
# noinspection PyMethodOverriding
def compute_metrics(self, *args, **kwargs) -> dict[str, jax.Array]:
# implemented by each respective architecture
raise NotImplementedError

def evaluate(
self,
x=None,
y=None,
batch_size=None,
verbose="auto",
sample_weight=None,
steps=None,
callbacks=None,
return_dict=False,
aggregate=True,
**kwargs,
):
self._assert_compile_called("evaluate")
# TODO: respect compiled trainable state
use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
if kwargs:
raise ValueError(f"Arguments not recognized: {kwargs}")

if use_cached_eval_dataset:
epoch_iterator = self._eval_epoch_iterator
else:
# Create an iterator that yields batches of
# input/target data.
epoch_iterator = JAXEpochIterator(
x=x,
y=y,
sample_weight=sample_weight,
batch_size=batch_size,
steps_per_epoch=steps,
shuffle=False,
steps_per_execution=self.steps_per_execution,
)

self._symbolic_build(iterator=epoch_iterator)
epoch_iterator.reset()

# Container that configures and calls callbacks.
if not isinstance(callbacks, callbacks_module.CallbackList):
callbacks = callbacks_module.CallbackList(
callbacks,
add_progbar=verbose != 0,
verbose=verbose,
epochs=1,
steps=epoch_iterator.num_batches,
model=self,
)
self._record_training_state_sharding_spec()

self.make_test_function()
self.stop_evaluating = False
callbacks.on_test_begin()
logs = {}
total_steps = 0
self.reset_metrics()

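        # The two helpers below implement `aggregate`: keep a running
        # elementwise sum of the per-batch logs, then divide by the step count.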
def _aggregate_fn(_logs, _step_logs):
if not _logs:
return _step_logs

return keras.tree.map_structure(keras.ops.add, _logs, _step_logs)

def _reduce_fn(_logs, _total_steps):
if _total_steps == 0:
return _logs

def _div(val):
return val / _total_steps

return keras.tree.map_structure(_div, _logs)

self._jax_state_synced = True
with epoch_iterator.catch_stop_iteration():
for step, iterator in epoch_iterator:
callbacks.on_test_batch_begin(step)

total_steps += 1

if self._jax_state_synced:
# The state may have been synced by a callback.
state = self._get_jax_state(
trainable_variables=True,
non_trainable_variables=True,
metrics_variables=True,
purge_model_variables=True,
)
self._jax_state_synced = False

step_logs, state = self.test_function(state, iterator)
(
trainable_variables,
non_trainable_variables,
metrics_variables,
) = state

if aggregate:
logs = _aggregate_fn(logs, step_logs)
else:
logs = step_logs

# Setting _jax_state enables callbacks to force a state sync
# if they need to.
self._jax_state = {
# I wouldn't recommend modifying non-trainable model state
# during evaluate(), but it's allowed.
"trainable_variables": trainable_variables,
"non_trainable_variables": non_trainable_variables,
"metrics_variables": metrics_variables,
}

# Dispatch callbacks. This takes care of async dispatch.
callbacks.on_test_batch_end(step, step_logs)

if self.stop_evaluating:
break

if aggregate:
logs = _reduce_fn(logs, total_steps)

# Reattach state back to model (if not already done by a callback).
self.jax_state_sync()

logs = self._get_metrics_result_or_logs(logs)
callbacks.on_test_end(logs)
self._jax_state = None
if not use_cached_eval_dataset:
# Only clear sharding if evaluate is not called from `fit`.
self._clear_jax_state_sharding()
if return_dict:
return logs
return self._flatten_metrics_in_order(logs)

def stateless_compute_metrics(
self,
trainable_variables: any,
    ...
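
The aggregation helpers are identical across the three backends: `_aggregate_fn` folds each step's logs into a running sum with `keras.tree.map_structure(keras.ops.add, ...)`, and `_reduce_fn` divides every leaf by the number of executed steps. A standalone sketch of the same arithmetic, with two made-up log dicts standing in for real batch metrics:

import keras

logs = {}
batches = [{"loss": 2.0, "kl": 0.5}, {"loss": 1.0, "kl": 0.3}]

for step_logs in batches:
    # _aggregate_fn: the first batch initializes the logs, every later
    # batch is added leaf-wise onto the running sum.
    logs = step_logs if not logs else keras.tree.map_structure(keras.ops.add, logs, step_logs)

# _reduce_fn: divide every leaf by the step count to turn the sum into a mean.
logs = keras.tree.map_structure(lambda value: value / len(batches), logs)
# leaf values are now loss=1.5 and kl=0.4 (as backend tensors)
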
bayesflow/approximators/backend_approximators/tensorflow_approximator.py
@@ -3,13 +3,112 @@

from bayesflow.utils import filter_kwargs

from keras.src.backend.tensorflow.trainer import TFEpochIterator
from keras.src import callbacks as callbacks_module


class TensorFlowApproximator(keras.Model):
# noinspection PyMethodOverriding
def compute_metrics(self, *args, **kwargs) -> dict[str, tf.Tensor]:
# implemented by each respective architecture
raise NotImplementedError

def evaluate(
self,
x=None,
y=None,
batch_size=None,
verbose="auto",
sample_weight=None,
steps=None,
callbacks=None,
return_dict=False,
aggregate=True,
**kwargs,
):
self._assert_compile_called("evaluate")
# TODO: respect compiled trainable state
use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
if kwargs:
raise ValueError(f"Arguments not recognized: {kwargs}")

if use_cached_eval_dataset:
epoch_iterator = self._eval_epoch_iterator
else:
# Create an iterator that yields batches of input/target data.
epoch_iterator = TFEpochIterator(
x=x,
y=y,
sample_weight=sample_weight,
batch_size=batch_size,
steps_per_epoch=steps,
shuffle=False,
distribute_strategy=self.distribute_strategy,
steps_per_execution=self.steps_per_execution,
)

self._maybe_symbolic_build(iterator=epoch_iterator)
epoch_iterator.reset()

# Container that configures and calls callbacks.
if not isinstance(callbacks, callbacks_module.CallbackList):
callbacks = callbacks_module.CallbackList(
callbacks,
add_progbar=verbose != 0,
verbose=verbose,
epochs=1,
steps=epoch_iterator.num_batches,
model=self,
)

self.make_test_function()
self.stop_evaluating = False
callbacks.on_test_begin()
logs = {}
total_steps = 0
self.reset_metrics()

def _aggregate_fn(_logs, _step_logs):
if not _logs:
return _step_logs

return keras.tree.map_structure(keras.ops.add, _logs, _step_logs)

def _reduce_fn(_logs, _total_steps):
if _total_steps == 0:
return _logs

def _div(val):
return val / _total_steps

return keras.tree.map_structure(_div, _logs)

with epoch_iterator.catch_stop_iteration():
for step, iterator in epoch_iterator:
callbacks.on_test_batch_begin(step)
total_steps += 1

step_logs = self.test_function(iterator)

if aggregate:
logs = _aggregate_fn(logs, step_logs)
else:
logs = step_logs

callbacks.on_test_batch_end(step, step_logs)
if self.stop_evaluating:
break

if aggregate:
logs = _reduce_fn(logs, total_steps)

logs = self._get_metrics_result_or_logs(logs)
callbacks.on_test_end(logs)

if return_dict:
return logs
return self._flatten_metrics_in_order(logs)

def test_step(self, data: dict[str, any]) -> dict[str, tf.Tensor]:
kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics)
return self.compute_metrics(**kwargs)
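
In every backend, `test_step` routes the data dict plus a `stage` flag through `filter_kwargs` before calling `compute_metrics`, so each architecture only receives the keyword arguments its `compute_metrics` signature declares. A minimal reimplementation of such a helper, for intuition only (the real one lives in `bayesflow.utils` and may differ in detail):

import inspect

def filter_kwargs_sketch(kwargs: dict, fn) -> dict:
    # Drop every entry that `fn` does not accept, so the call below
    # cannot fail on unexpected keyword arguments.
    accepted = inspect.signature(fn).parameters
    return {key: value for key, value in kwargs.items() if key in accepted}

def compute_metrics(x, stage="training"):
    return {}

data = {"x": 1.0, "unused": 2.0, "stage": "validation"}
compute_metrics(**filter_kwargs_sketch(data, compute_metrics))  # "unused" is dropped
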
bayesflow/approximators/backend_approximators/torch_approximator.py
@@ -3,13 +3,111 @@

from bayesflow.utils import filter_kwargs

from keras.src.backend.torch.trainer import TorchEpochIterator
from keras.src import callbacks as callbacks_module


class TorchApproximator(keras.Model):
# noinspection PyMethodOverriding
def compute_metrics(self, *args, **kwargs) -> dict[str, torch.Tensor]:
# implemented by each respective architecture
raise NotImplementedError

def evaluate(
self,
x=None,
y=None,
batch_size=None,
verbose="auto",
sample_weight=None,
steps=None,
callbacks=None,
return_dict=False,
aggregate=True,
**kwargs,
):
# TODO: respect compiled trainable state
use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
if kwargs:
raise ValueError(f"Arguments not recognized: {kwargs}")

if use_cached_eval_dataset:
epoch_iterator = self._eval_epoch_iterator
else:
# Create an iterator that yields batches of input/target data.
epoch_iterator = TorchEpochIterator(
x=x,
y=y,
sample_weight=sample_weight,
batch_size=batch_size,
steps_per_epoch=steps,
shuffle=False,
steps_per_execution=self.steps_per_execution,
)

self._symbolic_build(iterator=epoch_iterator)
epoch_iterator.reset()

# Container that configures and calls callbacks.
if not isinstance(callbacks, callbacks_module.CallbackList):
callbacks = callbacks_module.CallbackList(
callbacks,
add_progbar=verbose != 0,
verbose=verbose,
epochs=1,
steps=epoch_iterator.num_batches,
model=self,
)

# Switch the torch Module back to testing mode.
self.eval()

self.make_test_function()
self.stop_evaluating = False
callbacks.on_test_begin()
logs = {}
total_steps = 0
self.reset_metrics()

def _aggregate_fn(_logs, _step_logs):
if not _logs:
return _step_logs

return keras.tree.map_structure(keras.ops.add, _logs, _step_logs)

def _reduce_fn(_logs, _total_steps):
if _total_steps == 0:
return _logs

def _div(val):
return val / _total_steps

return keras.tree.map_structure(_div, _logs)

for step, data in epoch_iterator:
callbacks.on_test_batch_begin(step)
total_steps += 1
step_logs = self.test_function(data)

if aggregate:
logs = _aggregate_fn(logs, step_logs)
else:
logs = step_logs

callbacks.on_test_batch_end(step, step_logs)
if self.stop_evaluating:
break

if aggregate:
logs = _reduce_fn(logs, total_steps)

logs = self._get_metrics_result_or_logs(logs)
callbacks.on_test_end(logs)

if return_dict:
return logs
return self._flatten_metrics_in_order(logs)

def test_step(self, data: dict[str, any]) -> dict[str, torch.Tensor]:
kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics)
return self.compute_metrics(**kwargs)
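
One subtlety of the sum-then-divide reduction: every batch contributes with equal weight, so if the final batch is smaller than the rest, the result is a mean over batches rather than an exact mean over samples. With `aggregate=False` the previous behavior is preserved and only the last batch's logs are returned. A quick illustration of the batch-weighting effect, with made-up numbers:

batch_means = [2.0, 1.0]  # per-batch metric means
batch_sizes = [100, 10]   # a ragged final batch

mean_over_batches = sum(batch_means) / len(batch_means)  # 1.5, what _reduce_fn computes
mean_over_samples = sum(m * n for m, n in zip(batch_means, batch_sizes)) / sum(batch_sizes)  # ~1.91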