Skip to content

Commit

Permalink
add transform.LabelShift support framework (intel#1035)
Browse files Browse the repository at this point in the history
  • Loading branch information
changwangss authored Jul 1, 2022
1 parent 62e3d39 commit 6526970
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/transform.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ Neural Compressor supports built-in preprocessing methods on different framework
| Cast(dtype) | **dtype** (str, default ='float32') :The target data type | Convert image to given dtype | Cast: <br> &ensp;&ensp; dtype: float32 |
| AlignImageChannel(dim) | **dim** (int): The channel number of result image | Align image channel, now just support [H,W,4]->[H,W,3] and [H,W,3]->[H,W], input image must be PIL Image. <br> This transform is going to be deprecated. | AlignImageChannel: <br> &ensp;&ensp; dim: 3 |
| ResizeWithRatio(min_dim, max_dim, padding) | **min_dim** (int, default=800): Resizes the image such that its smaller dimension == min_dim <br> **max_dim** (int, default=1365): Ensures that the image longest side does not exceed this value <br> **padding** (bool, default=False): If true, pads image with zeros so its size is max_dim x max_dim | Resize image with aspect ratio and pad it to max shape(optional). If the image is padded, the label will be processed at the same time. The input image should be np.array. | ResizeWithRatio: <br> &ensp;&ensp; min_dim: 800 <br> &ensp;&ensp; max_dim: 1365 <br> &ensp;&ensp; padding: True |
| LabelShift(label_shift) | **label_shift**(int, default=0): number of label shift | Convert label to label - label_shift | LabelShift: <br> &ensp;&ensp; label_shift: 0 |

### MXNet

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def __call__(self, sample):
image = tf.dtypes.cast(image, dtype=self.dtype_map[self.dtype])
return image, label

@transform_registry(transform_type="LabelShift", \
process="postprocess", framework="tensorflow, onnxrt_qlinearops, onnxrt_integerops, engine")
@transform_registry(transform_type="LabelShift", process="postprocess", \
framework="pytorch, tensorflow, onnxrt_qlinearops, onnxrt_integerops, engine")
class LabelShift(BaseTransform):
"""Convert label to label - label_shift.
Expand Down

0 comments on commit 6526970

Please sign in to comment.