Skip to content

Change the default value of the label indices from [1] to [1:] #265

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mindocr/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class Evaluator:
metrics: metrics to evaluate network performance
pred_cast_fp32: whehter to cast network prediction to float 32. Set True if AMP is used.
input_indices: The indices of the data tuples which will be fed into the network. If it is None, then the first item will be fed only.
label_indices: The indices of the data tuples which will be marked as label. If it is None, then the second item will be marked as label.
label_indices: The indices of the data tuples which will be marked as label. If it is None, then the remaining items will be marked as label.
meta_data_indices: The indices for the data tuples which will be marked as meta data. If it is None, then the item indices not in input or label indices are marked as meta data.
"""

Expand Down Expand Up @@ -96,7 +96,7 @@ def eval(self):
if self.label_indices is not None:
gt = [data[x] for x in self.label_indices]
else:
gt = [data[1]]
gt = data[1:]

net_preds = self.net(*inputs)

Expand All @@ -113,7 +113,7 @@ def eval(self):
else:
# assume the indices not in input_indices or label_indices are all meta_data_indices
input_indices = set(self.input_indices) if self.input_indices is not None else {0}
label_indices = set(self.label_indices) if self.label_indices is not None else {1}
label_indices = set(self.label_indices) if self.label_indices is not None else set(range(1, len(data), 1))
meta_data_indices = sorted(set(range(len(data))) - input_indices - label_indices)
meta_info = [data[x] for x in meta_data_indices]

Expand Down
8 changes: 4 additions & 4 deletions mindocr/utils/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class NetWithLossWrapper(nn.Cell):
net (nn.Cell): network
loss_fn: loss function
input_indices: The indices of the data tuples which will be fed into the network. If it is None, then the first item will be fed only.
label_indices: The indices of the data tuples which will be fed into the loss function. If it is None, then the second item will be fed only.
label_indices: The indices of the data tuples which will be fed into the loss function. If it is None, then the remaining items will be fed.
'''
def __init__(self, net, loss_fn, pred_cast_fp32=False, input_indices=None, label_indices=None):
super().__init__(auto_prefix=False)
Expand Down Expand Up @@ -44,7 +44,7 @@ def construct(self, *args):
pred = [F.cast(p, mstype.float32) for p in pred]

if self.label_indices is None:
loss_val = self._loss_fn(pred, args[1])
loss_val = self._loss_fn(pred, *args[1:])
else:
loss_val = self._loss_fn(pred, *select_inputs_by_indices(args, self.label_indices))

Expand All @@ -60,7 +60,7 @@ class NetWithEvalWrapper(nn.Cell):
net (nn.Cell): network
loss_fn: loss function, if None, will not compute loss for evaluation dataset
input_indices: The indices of the data tuples which will be fed into the network. If it is None, then the first item will be fed only.
label_indices: The indices of the data tuples which will be fed into the loss function. If it is None, then the second item will be fed only.
label_indices: The indices of the data tuples which will be fed into the loss function. If it is None, then the remaining items will be fed.
'''
def __init__(self, net, loss_fn=None, input_indices=None, label_indices=None):
super().__init__(auto_prefix=False)
Expand All @@ -84,7 +84,7 @@ def construct(self, *args):
pred = self._net(*select_inputs_by_indices(args, self.input_indices))

if self.label_indices is None:
labels = [args[1]]
labels = args[1:]
else:
labels = select_inputs_by_indices(args, self.label_indices)

Expand Down