Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] refactor data flow and engine library #1054

Merged
merged 26 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
datapreprocessor, formatting
  • Loading branch information
zengyh1900 committed Aug 24, 2022
commit ab4009b59f8abcad5db38aa5b7d45070ade63ec3
2 changes: 1 addition & 1 deletion mmedit/datasets/transforms/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def transform(self, results: dict) -> dict:

data_sample.set_metainfo(metainfo=metainfo)

packed_results['data_sample'] = data_sample
packed_results['data_samples'] = data_sample

return packed_results

Expand Down
4 changes: 2 additions & 2 deletions mmedit/evaluation/metrics/matting.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(
super().__init__(**kwargs)

def process(self, data_batch: Sequence[dict],
predictions: Sequence[dict]) -> None:
data_samples: Sequence[dict]) -> None:
"""Process one batch of data and predictions.

Args:
Expand All @@ -93,7 +93,7 @@ def process(self, data_batch: Sequence[dict],
data_samples (Sequence[dict]): A batch of outputs from
the model.
"""
for data, prediction in zip(data_batch, predictions):
for data, prediction in zip(data_batch, data_samples):
pred_alpha, gt_alpha, _ = _fetch_data_and_check(data, prediction)

# divide by 1000 to reduce the magnitude of the result
Expand Down
14 changes: 9 additions & 5 deletions mmedit/models/data_preprocessors/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ def forward(
model input.
"""

inputs, batch_data_samples = self.collate_data(data)
# inputs, batch_data_samples = self.collate_data(data)
data = super().forward(data=data, training=training)
inputs, batch_data_samples = data['inputs'], data['data_samples']

# Check if input is normalized to [0, 1]
self.norm_input_flag = (inputs[0].max() <= 1)
Expand All @@ -113,17 +115,19 @@ def forward(
for _input in inputs]

# Pad and stack Tensor.
batch_inputs, self.padded_sizes = stack_batch(inputs,
self.pad_size_divisor,
self.pad_args)
inputs, self.padded_sizes = stack_batch(inputs, self.pad_size_divisor,
self.pad_args)

if training:
for data_sample in batch_data_samples:
data_sample.gt_img.data = (
(data_sample.gt_img.data - self.outputs_mean[0]) /
self.outputs_std[0])

return batch_inputs, batch_data_samples
data['inputs'] = inputs
data['data_samples'] = batch_data_samples
# return inputs, batch_data_samples
return data

def destructor(self, batch_tensor: torch.Tensor):
"""Destructor of data processor.
Expand Down
11 changes: 7 additions & 4 deletions mmedit/models/data_preprocessors/mattor_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def forward(self,
# Image may of different size when testing
assert len(data) == 1, ('only batch_size=1 '
'is supported for testing.')

data = super().forward(data, training=training)
images, trimaps, batch_data_samples = self.collate_data(data)

batch_images = self._proc_inputs(images)
Expand Down Expand Up @@ -164,16 +164,19 @@ def forward(self,
# N, (4/6), H, W
batch_inputs = torch.cat((batch_images, batch_trimaps), dim=1)

return batch_inputs, batch_data_samples
data['inputs'] = batch_inputs
data['data_samples'] = batch_data_samples
# return batch_inputs, batch_data_samples
return data

def collate_data(self, data: Sequence[dict]) -> Tuple[list, list, list]:
"""Collating and moving data to the target device.

See base class ``BaseDataPreprocessor`` for detailed information.
"""
inputs = [data_['inputs'] for data_ in data]
trimaps = [data_['data_sample'].trimap.data for data_ in data]
batch_data_samples = [data_['data_sample'] for data_ in data]
trimaps = [data_['data_samples'].trimap.data for data_ in data]
batch_data_samples = [data_['data_samples'] for data_ in data]

# Move data from CPU to corresponding device.
inputs = [_input.to(self.device) for _input in inputs]
Expand Down