Description
When calculating the total steps, shouldn't we use number of batches * epoch size? In that case it would be `self.total_steps = (len(train_loader.dataset) // tb_size) * ab_size` instead of `self.total_steps = (len(train_loader.dataset) // tb_size) // ab_size`.
Please correct me if I'm wrong anywhere.
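
To make the difference concrete, here is a small sketch that evaluates both expressions. The numbers are hypothetical, and it assumes `tb_size` is the effective training batch size and `ab_size` is the number of epochs; the definition of `ab_size` is not shown here, so if it also folds in something else (for example a gradient accumulation factor), the right expression may differ from both.

```python
# Hypothetical values, chosen only for illustration.
dataset_len = 10_000  # stands in for len(train_loader.dataset)
tb_size = 32          # assumption: effective training batch size
ab_size = 3           # assumption: number of epochs

# Number of full batches in one pass over the dataset.
batches_per_epoch = dataset_len // tb_size

# Proposed: batches per epoch multiplied by ab_size.
steps_multiplied = batches_per_epoch * ab_size

# Current: batches per epoch floor-divided by ab_size.
steps_divided = batches_per_epoch // ab_size

print(batches_per_epoch, steps_multiplied, steps_divided)
```

Under these assumptions the floor-divided version is smaller than a single epoch's worth of batches, whereas the multiplied version grows with `ab_size`.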