From 44e8881da721ab91eca1e1456f21272f158f5e6c Mon Sep 17 00:00:00 2001 From: Frances Hartwell Date: Wed, 11 Oct 2023 09:05:48 -0400 Subject: [PATCH] update Par progress bar --- deepecho/models/par.py | 20 +++++++++++++++++--- tests/integration/test_par.py | 20 ++++++++++++++++++++ 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/deepecho/models/par.py b/deepecho/models/par.py index 0b06005..b728989 100644 --- a/deepecho/models/par.py +++ b/deepecho/models/par.py @@ -105,6 +105,7 @@ def __init__(self, epochs=128, sample_size=1, cuda=True, verbose=True): self.device = torch.device(device) self.verbose = verbose + self.loss_values = pd.DataFrame(columns=['Epoch', 'Loss']) LOGGER.info('%s instance created', self) @@ -321,9 +322,10 @@ def fit_sequences(self, sequences, context_types, data_types): self._model = PARNet(self._data_dims, self._ctx_dims).to(self.device) optimizer = torch.optim.Adam(self._model.parameters(), lr=1e-3) - iterator = range(self.epochs) + iterator = tqdm(range(self.epochs), disable=(not self.verbose)) if self.verbose: - iterator = tqdm(iterator) + pbar_description = 'Loss ({loss:.3f})' + iterator.set_description(pbar_description.format(loss=0)) X_padded, seq_len = torch.nn.utils.rnn.pad_packed_sequence(X) for epoch in iterator: @@ -333,8 +335,20 @@ def fit_sequences(self, sequences, context_types, data_types): optimizer.zero_grad() loss = self._compute_loss(X_padded[1:, :, :], Y_padded[:-1, :, :], seq_len) loss.backward() + + epoch_loss_df = pd.DataFrame({ + 'Epoch': [epoch], + 'Loss': [loss.item()] + }) + if not self.loss_values.empty: + self.loss_values = pd.concat( + [self.loss_values, epoch_loss_df] + ).reset_index(drop=True) + else: + self.loss_values = epoch_loss_df + if self.verbose: - iterator.set_description(f'Epoch {epoch +1} | Loss {loss.item()}') + iterator.set_description(pbar_description.format(loss=loss.item())) optimizer.step() diff --git a/tests/integration/test_par.py b/tests/integration/test_par.py index f633a35..4e172ec 100644 --- a/tests/integration/test_par.py +++ b/tests/integration/test_par.py @@ -35,6 +35,10 @@ def test_basic(self): model.fit_sequences(sequences, context_types, data_types) model.sample_sequence([]) + # Assert + assert set(model.loss_values.columns) == {'Epoch', 'Loss'} + assert len(model.loss_values) == 128 + def test_conditional(self): """Test the ``PARModel`` with conditional sampling.""" sequences = [ @@ -60,6 +64,10 @@ def test_conditional(self): model.fit_sequences(sequences, context_types, data_types) model.sample_sequence([0]) + # Assert + assert set(model.loss_values.columns) == {'Epoch', 'Loss'} + assert len(model.loss_values) == 128 + def test_mixed(self): """Test the ``PARModel`` with mixed input data.""" sequences = [ @@ -85,6 +93,10 @@ def test_mixed(self): model.fit_sequences(sequences, context_types, data_types) model.sample_sequence([0]) + # Assert + assert set(model.loss_values.columns) == {'Epoch', 'Loss'} + assert len(model.loss_values) == 128 + def test_count(self): """Test the PARModel with datatype ``count``.""" sequences = [ @@ -110,6 +122,10 @@ def test_count(self): model.fit_sequences(sequences, context_types, data_types) model.sample_sequence([0]) + # Assert + assert set(model.loss_values.columns) == {'Epoch', 'Loss'} + assert len(model.loss_values) == 128 + def test_variable_length(self): """Test ``PARModel`` with variable data length.""" sequences = [ @@ -134,3 +150,7 @@ def test_variable_length(self): model = PARModel() model.fit_sequences(sequences, context_types, data_types) model.sample_sequence([0]) + + # Assert + assert set(model.loss_values.columns) == {'Epoch', 'Loss'} + assert len(model.loss_values) == 128