diff --git a/modyn/models/rho_loss_twin_model/rho_loss_twin_model.py b/modyn/models/rho_loss_twin_model/rho_loss_twin_model.py index 30ba99f1f..56817d475 100644 --- a/modyn/models/rho_loss_twin_model/rho_loss_twin_model.py +++ b/modyn/models/rho_loss_twin_model/rho_loss_twin_model.py @@ -2,6 +2,7 @@ from typing import Any, Optional import torch +import copy from modyn.utils import dynamic_module_import from torch import nn @@ -26,8 +27,9 @@ def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool) # we only need the inner model, not the wrapper self._models = nn.ModuleList( [ - model_handler(rho_model_config, device, amp).model, - model_handler(rho_model_config, device, amp).model, + # some models change the model config dict during initialization + model_handler(copy.deepcopy(rho_model_config), device, amp).model, + model_handler(copy.deepcopy(rho_model_config), device, amp).model, ] ) self._models_seen_ids: list[set[int]] = [set(), set()]