diff --git a/src/colbert/index_updater.py b/src/colbert/index_updater.py index 84ba2ff2..4e2c40ba 100644 --- a/src/colbert/index_updater.py +++ b/src/colbert/index_updater.py @@ -390,7 +390,7 @@ def _add_pid_to_ivf(self, partitions, pid): assert sum(new_ivf_lengths) == len(new_ivf) # Replace the current ivf with new_ivf - self.curr_ivf = torch.tensor(new_ivf) + self.curr_ivf = torch.tensor(new_ivf, dtype=torch.int32) self.curr_ivf_lengths = torch.tensor(new_ivf_lengths) def _write_to_last_chunk(self, pid_start, pid_end, emb_start, emb_end):