diff --git a/heat/core/io.py b/heat/core/io.py index 975c368f8f..b029cd2bc6 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -1032,9 +1032,15 @@ def save_csv( if len(data.lshape) == 1: row = fmt.format(data.larray[i]) else: + if data.lshape[1] == 0: + break row = sep.join(fmt.format(item) for item in data.larray[i]) - if data.split is None or data.split == 0 or data.comm.rank == (data.comm.size - 1): + if ( + data.split is None + or data.split == 0 + or displs[data.comm.rank] + data.lshape[1] == data.shape[1] + ): row = row + "\n" else: row = row + sep diff --git a/heat/core/tests/test_io.py b/heat/core/tests/test_io.py index b3747e1458..43e6212019 100644 --- a/heat/core/tests/test_io.py +++ b/heat/core/tests/test_io.py @@ -216,6 +216,11 @@ def test_save_csv(self): comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) self.assertTrue(ht.max(data - comparison.reshape(data.shape)).item() < 0.0000001) + data = ht.random.randint(10, 100, (50, 2), split=1) + data.save(self.CSV_OUT_PATH) + comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) + self.assertTrue((data == comparison).all().item()) + def test_load_exception(self): # correct extension, file does not exist if ht.io.supports_hdf5():