Skip to content

Commit

Permalink
Bugfix: nprocs > shape[1]
Browse files Browse the repository at this point in the history
Having more processes than chunks in split 1 did not work. Rather than 
checking whether we are the last (overall) rank, we check whether we 
have the last chunk of data and don't write anything if we have no data. 
Last chunk is relevant to distinguish newline or separator addition at 
the end of our buffer.
  • Loading branch information
bhagemeier committed Mar 28, 2022
1 parent 4e7a23c commit a8740cb
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
8 changes: 7 additions & 1 deletion heat/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,9 +1032,15 @@ def save_csv(
if len(data.lshape) == 1:
row = fmt.format(data.larray[i])
else:
if data.lshape[1] == 0:
break
row = sep.join(fmt.format(item) for item in data.larray[i])

if data.split is None or data.split == 0 or data.comm.rank == (data.comm.size - 1):
if (
data.split is None
or data.split == 0
or displs[data.comm.rank] + data.lshape[1] == data.shape[1]
):
row = row + "\n"
else:
row = row + sep
Expand Down
5 changes: 5 additions & 0 deletions heat/core/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,11 @@ def test_save_csv(self):
comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype)
self.assertTrue(ht.max(data - comparison.reshape(data.shape)).item() < 0.0000001)

data = ht.random.randint(10, 100, (50, 2), split=1)
data.save(self.CSV_OUT_PATH)
comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype)
self.assertTrue((data == comparison).all().item())

def test_load_exception(self):
# correct extension, file does not exist
if ht.io.supports_hdf5():
Expand Down

0 comments on commit a8740cb

Please sign in to comment.