Skip to content

Commit b67804f

Browse files
committed
fix callback/
1 parent b9bc6c2 commit b67804f

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

src/cellflow/training/_callbacks.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,13 @@ def on_log_iteration(
310310
valid_true_data_decoded = jtu.tree_map(self.reconstruct_data, valid_true_data)
311311
predicted_data_decoded = jtu.tree_map(self.reconstruct_data, valid_pred_data)
312312

313-
metrics = super().on_log_iteration(valid_true_data_decoded, predicted_data_decoded)
313+
metrics = super().on_log_iteration(
314+
valid_source_data={},
315+
valid_true_data=valid_true_data_decoded,
316+
valid_pred_data=predicted_data_decoded,
317+
solver=solver,
318+
)
319+
314320
metrics = {f"{self.log_prefix}{k}": v for k, v in metrics.items()}
315321
return metrics
316322

@@ -382,7 +388,12 @@ def on_log_iteration(
382388
valid_true_data_decoded = jtu.tree_map(self.reconstruct_data, valid_true_data_in_anndata)
383389
predicted_data_decoded = jtu.tree_map(self.reconstruct_data, predicted_data_in_anndata)
384390

385-
metrics = super().on_log_iteration(valid_true_data_decoded, predicted_data_decoded)
391+
metrics = super().on_log_iteration(
392+
valid_source_data={},
393+
valid_true_data=valid_true_data_decoded,
394+
valid_pred_data=predicted_data_decoded,
395+
solver=solver,
396+
)
386397
metrics = {f"{self.log_prefix}{k}": v for k, v in metrics.items()}
387398
return metrics
388399

0 commit comments

Comments
 (0)