Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix device for initial RHO-LOSS tensor (#525)
The initial tensor resides on CPU. For me, it failed when running RHO-LOSS, since the very first ` self.rho_loss = torch.cat([self.rho_loss, training_loss - irreducible_loss]).to(training_loss.dtype)` was on two different devices, the CPU (initial tensor) and the GPU (training loss and IR loss). This PR fixes that by moving the initial tensor to the correct device.
- Loading branch information