Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
frances-h committed Sep 26, 2023
1 parent 60a65f0 commit 2a6fd65
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions ctgan/synthesizers/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,14 +301,6 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
DeprecationWarning
)

epoch_iterator = range(epochs)
if self._verbose:
progress_bar = tqdm(range(epochs))
epoch_iterator = progress_bar

description = 'Gen. ({gen:.2f}) | Discrim. ({dis:.2f})'
progress_bar.set_description(description.format(gen=0, dis=0))

self._transformer = DataTransformer()
self._transformer.fit(train_data, discrete_columns)

Expand Down Expand Up @@ -348,6 +340,11 @@ def fit(self, train_data, discrete_columns=(), epochs=None):

self.loss_values = pd.DataFrame(columns=['Epoch', 'Generator Loss', 'Distriminator Loss'])

epoch_iterator = tqdm(range(epochs), disable=(not self._verbose))
if self._verbose:
description = 'Gen. ({gen:.2f}) | Discrim. ({dis:.2f})'
epoch_iterator.set_description(description.format(gen=0, dis=0))

steps_per_epoch = max(len(train_data) // self._batch_size, 1)
for i in epoch_iterator:
for id_ in range(steps_per_epoch):
Expand Down Expand Up @@ -441,7 +438,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
self.loss_values = epoch_loss_df

if self._verbose:
progress_bar.set_description(
epoch_iterator.set_description(
description.format(gen=generator_loss, dis=discriminator_loss)
)

Expand Down

0 comments on commit 2a6fd65

Please sign in to comment.