Skip to content

Commit 7e770a0

Browse files
Return self in _fit()
1 parent de4ae71 commit 7e770a0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

autoPyTorch/pipeline/components/training/trainer/base_trainer_choice.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> autoPyTorchCom
203203

204204
return cast(autoPyTorchComponent, self.choice)
205205

206-
def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> torch.nn.Module:
206+
def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoice':
207207
"""
208208
Fits a component by using an input dictionary with pre-requisites
209209
@@ -336,7 +336,7 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> torch.nn.Modu
336336
# Tag as fitted
337337
self.fitted_ = True
338338

339-
return X['network'].state_dict()
339+
return self
340340

341341
def early_stop_handler(self, X: Dict[str, Any]) -> bool:
342342
"""

0 commit comments

Comments
 (0)