diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py index bb732760..82dd3aba 100644 --- a/ctgan/synthesizers/ctgan.py +++ b/ctgan/synthesizers/ctgan.py @@ -301,14 +301,6 @@ def fit(self, train_data, discrete_columns=(), epochs=None): DeprecationWarning ) - epoch_iterator = range(epochs) - if self._verbose: - progress_bar = tqdm(range(epochs)) - epoch_iterator = progress_bar - - description = 'Gen. ({gen:.2f}) | Discrim. ({dis:.2f})' - progress_bar.set_description(description.format(gen=0, dis=0)) - self._transformer = DataTransformer() self._transformer.fit(train_data, discrete_columns) @@ -348,6 +340,11 @@ def fit(self, train_data, discrete_columns=(), epochs=None): self.loss_values = pd.DataFrame(columns=['Epoch', 'Generator Loss', 'Distriminator Loss']) + epoch_iterator = tqdm(range(epochs), disable=(not self._verbose)) + if self._verbose: + description = 'Gen. ({gen:.2f}) | Discrim. ({dis:.2f})' + epoch_iterator.set_description(description.format(gen=0, dis=0)) + steps_per_epoch = max(len(train_data) // self._batch_size, 1) for i in epoch_iterator: for id_ in range(steps_per_epoch): @@ -441,7 +438,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None): self.loss_values = epoch_loss_df if self._verbose: - progress_bar.set_description( + epoch_iterator.set_description( description.format(gen=generator_loss, dis=discriminator_loss) )