@@ -61,8 +61,6 @@ def __init__(
6161
6262 :param AbstractProblem problem: The problem to be solved.
6363 :param torch.nn.Module models: The neural network models to be used.
64- :param int ensemble_dim: The dimension along which the ensemble
65- outputs are stacked. Default is 0.
6664 :param Optimizer optimizer: The optimizer to be used.
6765 If ``None``, the :class:`torch.optim.Adam` optimizer is used.
6866 Default is ``None``.
@@ -73,6 +71,8 @@ def __init__(
7371 If ``None``, no weighting schema is used. Default is ``None``.
7472 :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
7573 Default is ``True``.
74+ :param int ensemble_dim: The dimension along which the ensemble
75+ outputs are stacked. Default is 0.
7676 """
7777 super ().__init__ (
7878 problem , models , optimizers , schedulers , weighting , use_lt
@@ -90,7 +90,8 @@ def forward(self, x, ensemble_idx=None):
9090
9191 :param LabelTensor x: The input tensor to the models.
9292 :param int ensemble_idx: Optional index to select a specific
93- model from the ensemble.
93+ model from the ensemble. If ``None`` results for all models are
94+ stacked in ``ensemble_dim`` dimension. Default is ``None``.
9495 :return: The output of the selected model or the stacked
9596 outputs from all models.
9697 :rtype: LabelTensor
@@ -100,7 +101,7 @@ def forward(self, x, ensemble_idx=None):
100101 return self .models [ensemble_idx ].forward (x )
101102 # otherwise return the stacked output
102103 return torch .stack (
103- [self .forward (x , idx ) for idx in range (self .num_ensembles )],
104+ [self .forward (x , idx ) for idx in range (self .num_ensemble )],
104105 dim = self .ensemble_dim ,
105106 )
106107
@@ -125,8 +126,9 @@ def training_step(self, batch):
125126 # perform backpropagation
126127 self .manual_backward (loss )
127128 # optimize
128- for opt in self .optimizers :
129+ for opt , sched in zip ( self .optimizers , self . schedulers ) :
129130 opt .instance .step ()
131+ sched .instance .step ()
130132 return loss
131133
132134 @property
@@ -140,7 +142,7 @@ def ensemble_dim(self):
140142 return self ._ensemble_dim
141143
142144 @property
143- def num_ensembles (self ):
145+ def num_ensemble (self ):
144146 """
145147 The number of models in the ensemble.
146148
0 commit comments