Skip to content

Commit

Permalink
Make CSV writer naming consistent (step instead of batch) (pytorch#260)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#260

As title

Reviewed By: edward-io

Differential Revision: D40740658

fbshipit-source-id: 6711bf758c2b34fc4790699f9afe567dbe77ab4d
  • Loading branch information
ananthsub authored and facebook-github-bot committed Oct 26, 2022
1 parent bc64893 commit 72e13df
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
6 changes: 3 additions & 3 deletions tests/runner/callbacks/test_csv_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


class CustomCSVWriter(BaseCSVWriter):
def get_batch_output_rows(
def get_step_output_rows(
self,
state: State,
unit: PredictUnit[TPredictData],
Expand All @@ -31,7 +31,7 @@ def get_batch_output_rows(


class CustomCSVWriterSingleRow(BaseCSVWriter):
def get_batch_output_rows(
def get_step_output_rows(
self,
state: State,
unit: PredictUnit[TPredictData],
Expand Down Expand Up @@ -100,7 +100,7 @@ def test_csv_writer_with_no_output_rows_def(self) -> None:
dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
state = init_predict_state(dataloader=dataloader)

# Throw exception because get_batch_output_rows is not defined.
# Throw exception because get_step_output_rows is not defined.
with self.assertRaises(TypeError):
csv_callback = BaseCSVWriter(
header_row=_HEADER_ROW, dir_path="", filename=_FILENAME
Expand Down
14 changes: 7 additions & 7 deletions torchtnt/runner/callbacks/base_csv_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class BaseCSVWriter(Callback, ABC):
This callback provides an interface to simplify writing outputs during prediction
into a CSV file. This callback must be extended with an implementation for
``get_batch_output_rows`` to write the desired outputs as rows in the CSV file.
``get_step_output_rows`` to write the desired outputs as rows in the CSV file.
By default, outputs at each step across all processes will be written into the same CSV file.
The outputs in each row is a a list of strings, and should match
Expand Down Expand Up @@ -53,7 +53,7 @@ def __init__(
self._writer: csv._writer = csv.writer(self._file, delimiter=delimiter)

@abstractmethod
def get_batch_output_rows(
def get_step_output_rows(
self,
state: State,
unit: TPredictUnit,
Expand All @@ -69,15 +69,15 @@ def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
assert state.predict_state is not None
step_output = state.predict_state.step_output
batch_output_rows = self.get_batch_output_rows(state, unit, step_output)
output_rows = self.get_step_output_rows(state, unit, step_output)

# Check whether the first item is a list or not
if len(batch_output_rows) > 0:
if isinstance(batch_output_rows[0], list):
for row in batch_output_rows:
if len(output_rows) > 0:
if isinstance(output_rows[0], list):
for row in output_rows:
self._writer.writerow(row)
else:
self._writer.writerow(batch_output_rows)
self._writer.writerow(output_rows)

def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
self._file.flush()
Expand Down

0 comments on commit 72e13df

Please sign in to comment.