Skip to content

Commit

Permalink
[api] Adds center fit image operation for Yolo
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Aug 17, 2024
1 parent ede8264 commit b87cc7f
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 84 deletions.
59 changes: 59 additions & 0 deletions api/src/main/java/ai/djl/modality/cv/transform/CenterFit.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.cv.transform;

import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.Transform;

/**
 * A {@link Transform} that fits an image to a fixed size without scaling it: a dimension larger
 * than the target is center-cropped, and a dimension smaller than the target is padded with
 * zeros on both sides.
 */
public class CenterFit implements Transform {

    private final int width;
    private final int height;

    /**
     * Creates a {@code CenterFit} {@link Transform} that fits images to the given width and
     * height.
     *
     * @param width the desired width
     * @param height the desired height
     */
    public CenterFit(int width, int height) {
        this.width = width;
        this.height = height;
    }

    /** {@inheritDoc} */
    @Override
    public NDArray transform(NDArray array) {
        Shape shape = array.getShape();
        // Axis 0 is read as the image height and axis 1 as the image width.
        int w = (int) shape.get(1);
        int h = (int) shape.get(0);
        if (w > width || h > height) {
            // Only the oversized dimension(s) shrink; a smaller dimension keeps its size here.
            array = NDImageUtils.centerCrop(array, Math.min(w, width), Math.min(h, height));
        }
        // Amount still missing per dimension after the optional crop (never negative).
        int padW = Math.max(0, width - w);
        int padH = Math.max(0, height - h);
        if (padW > 0 || padH > 0) {
            // Split the padding across both sides; an odd remainder goes to the second side.
            int padW1 = padW / 2;
            int padH1 = padH / 2;
            // NOTE(review): the pairs are presumably consumed from the last axis backwards
            // (no padding on the last axis, then width, then height) - confirm against
            // NDArray.pad.
            Shape padding = new Shape(0, 0, padW1, padW - padW1, padH1, padH - padH1);
            array = array.pad(padding, 0);
        }
        return array;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ai.djl.Model;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.CenterFit;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
Expand Down Expand Up @@ -50,6 +51,8 @@ public abstract class BaseImageTranslator<T> implements Translator<Image, T> {

private Image.Flag flag;
private Batchifier batchifier;
protected int width;
protected int height;

/**
* Constructs an ImageTranslator with the provided builder.
Expand All @@ -60,6 +63,8 @@ public BaseImageTranslator(BaseBuilder<?> builder) {
flag = builder.flag;
pipeline = builder.pipeline;
batchifier = builder.batchifier;
width = builder.width;
height = builder.height;
}

/** {@inheritDoc} */
Expand All @@ -72,6 +77,8 @@ public Batchifier getBatchifier() {
@Override
public NDList processInput(TranslatorContext ctx, Image input) {
    NDArray pixels = input.toNDArray(ctx.getNDManager(), flag);
    // Store the source image dimensions on the context under the "width" and "height" keys,
    // presumably so that output processing can retrieve them later.
    ctx.setAttachment("width", input.getWidth());
    ctx.setAttachment("height", input.getHeight());
    NDList inputs = new NDList(pixels);
    return pipeline.transform(inputs);
}

Expand Down Expand Up @@ -171,6 +178,10 @@ protected void configPreProcess(Map<String, ?> arguments) {
if (ArgumentsUtil.booleanValue(arguments, "centerCrop", false)) {
addTransform(new CenterCrop(width, height));
}
String centerFit = ArgumentsUtil.stringValue(arguments, "centerFit", "false");
if ("true".equals(centerFit)) {
addTransform(new CenterFit(width, height));
}
if (ArgumentsUtil.booleanValue(arguments, "toTensor", true)) {
addTransform(new ToTensor());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@ public abstract class ObjectDetectionTranslator extends BaseImageTranslator<Dete
protected float threshold;
private SynsetLoader synsetLoader;
protected List<String> classes;
protected double imageWidth;
protected double imageHeight;
protected boolean applyRatio;
protected boolean removePadding;

/**
* Creates the {@link ObjectDetectionTranslator} from the given builder.
Expand All @@ -42,9 +41,8 @@ protected ObjectDetectionTranslator(ObjectDetectionBuilder<?> builder) {
super(builder);
this.threshold = builder.threshold;
this.synsetLoader = builder.synsetLoader;
this.imageWidth = builder.imageWidth;
this.imageHeight = builder.imageHeight;
this.applyRatio = builder.applyRatio;
this.removePadding = builder.removePadding;
}

/** {@inheritDoc} */
Expand All @@ -61,9 +59,8 @@ public abstract static class ObjectDetectionBuilder<T extends ObjectDetectionBui
extends ClassificationBuilder<T> {

protected float threshold = 0.2f;
protected double imageWidth;
protected double imageHeight;
protected boolean applyRatio;
protected boolean removePadding;

/**
* Sets the threshold for prediction accuracy.
Expand All @@ -78,19 +75,6 @@ public T optThreshold(float threshold) {
return self();
}

/**
* Sets the optional rescale size.
*
* @param imageWidth the width to rescale images to
* @param imageHeight the height to rescale images to
* @return this builder
*/
public T optRescaleSize(double imageWidth, double imageHeight) {
this.imageWidth = imageWidth;
this.imageHeight = imageHeight;
return self();
}

/**
* Determine Whether to divide output object width/height on the inference result. Default
* false.
Expand All @@ -108,37 +92,17 @@ public T optApplyRatio(boolean value) {
return self();
}

/**
* Get resized image width.
*
* @return image width
*/
public double getImageWidth() {
return imageWidth;
}

/**
* Get resized image height.
*
* @return image height
*/
public double getImageHeight() {
return imageHeight;
}

/** {@inheritDoc} */
@Override
protected void configPostProcess(Map<String, ?> arguments) {
super.configPostProcess(arguments);
if (ArgumentsUtil.booleanValue(arguments, "rescale")) {
optRescaleSize(width, height);
}
if (ArgumentsUtil.booleanValue(arguments, "optApplyRatio")
|| ArgumentsUtil.booleanValue(arguments, "applyRatio")) {
optApplyRatio(true);
optRescaleSize(width, height);
}
threshold = ArgumentsUtil.floatValue(arguments, "threshold", 0.2f);
String centerFit = ArgumentsUtil.stringValue(arguments, "centerFit", "false");
removePadding = "true".equals(centerFit);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,14 @@ public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
}
String className = classes.get(classId);
float[] box = boundingBoxes.get(i).toFloatArray();
// rescale box coordinates by imageWidth and imageHeight
double x = imageWidth > 0 ? box[0] / imageWidth : box[0];
double y = imageHeight > 0 ? box[1] / imageHeight : box[1];
double w = imageWidth > 0 ? box[2] / imageWidth - x : box[2] - x;
double h = imageHeight > 0 ? box[3] / imageHeight - y : box[3] - y;
// rescale box coordinates by width and height
double x = width > 0 ? box[0] / width : box[0];
double y = height > 0 ? box[1] / height : box[1];
double w = width > 0 ? box[2] / width - x : box[2] - x;
double h = height > 0 ? box[3] / height - y : box[3] - y;
Rectangle rect;
if (applyRatio) {
rect =
new Rectangle(
x / imageWidth,
y / imageHeight,
w / imageWidth,
h / imageHeight);
rect = new Rectangle(x / width, y / height, w / width, h / height);
} else {
rect = new Rectangle(x, y, w, h);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ public class YoloPoseTranslator extends BaseImageTranslator<Joints[]> {

private static final int MAX_DETECTION = 300;

private int width;
private int height;
private float threshold;
private float nmsThreshold;

Expand All @@ -41,8 +39,6 @@ public class YoloPoseTranslator extends BaseImageTranslator<Joints[]> {
*/
public YoloPoseTranslator(Builder builder) {
super(builder);
this.width = builder.width;
this.height = builder.height;
this.threshold = builder.threshold;
this.nmsThreshold = builder.nmsThreshold;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ public class YoloSegmentationTranslator extends YoloV5Translator {

private float threshold;
private float nmsThreshold;
private int width;
private int height;

/**
* Creates the instance segmentation translator from the given builder.
Expand All @@ -45,8 +43,6 @@ public YoloSegmentationTranslator(Builder builder) {
super(builder);
this.threshold = builder.threshold;
this.nmsThreshold = builder.nmsThreshold;
this.width = builder.width;
this.height = builder.height;
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
NDArray boundingBoxes = list.get(2);
int detected = Math.toIntExact(probs.length);

NDArray xMin = boundingBoxes.get(":, 0").clip(0, imageWidth).div(imageWidth);
NDArray yMin = boundingBoxes.get(":, 1").clip(0, imageHeight).div(imageHeight);
NDArray xMax = boundingBoxes.get(":, 2").clip(0, imageWidth).div(imageWidth);
NDArray yMax = boundingBoxes.get(":, 3").clip(0, imageHeight).div(imageHeight);
NDArray xMin = boundingBoxes.get(":, 0").clip(0, width).div(width);
NDArray yMin = boundingBoxes.get(":, 1").clip(0, height).div(height);
NDArray xMax = boundingBoxes.get(":, 2").clip(0, width).div(width);
NDArray yMax = boundingBoxes.get(":, 3").clip(0, height).div(height);

float[] boxX = xMin.toFloatArray();
float[] boxY = yMin.toFloatArray();
Expand All @@ -67,10 +67,10 @@ public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
if (applyRatio) {
rect =
new Rectangle(
boxX[i] / imageWidth,
boxY[i] / imageHeight,
boxWidth[i] / imageWidth,
boxHeight[i] / imageHeight);
boxX[i] / width,
boxY[i] / height,
boxWidth[i] / width,
boxHeight[i] / height);
} else {
rect = new Rectangle(boxX[i], boxY[i], boxWidth[i], boxHeight[i]);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

/**
* A translator for YoloV5 models. This was tested with ONNX exported Yolo models. For details check
* here: https://github.com/ultralytics/yolov5
* <a href="https://github.com/ultralytics/yolov5">here</a>
*/
public class YoloV5Translator extends ObjectDetectionTranslator {

Expand Down Expand Up @@ -68,7 +68,11 @@ public static YoloV5Translator.Builder builder(Map<String, ?> arguments) {
}

protected DetectedObjects nms(
List<Rectangle> boxes, List<Integer> classIds, List<Float> scores) {
int imageWidth,
int imageHeight,
List<Rectangle> boxes,
List<Integer> classIds,
List<Float> scores) {
List<String> retClasses = new ArrayList<>();
List<Double> retProbs = new ArrayList<>();
List<BoundingBox> retBB = new ArrayList<>();
Expand All @@ -94,22 +98,30 @@ protected DetectedObjects nms(
retClasses.add(classes.get(id));
retProbs.add(scores.get(pos).doubleValue());
Rectangle rect = boxes.get(pos);
if (applyRatio) {
retBB.add(
if (removePadding) {
int padW = (width - imageWidth) / 2;
int padH = (height - imageHeight) / 2;
rect =
new Rectangle(
rect.getX() / imageWidth,
rect.getY() / imageHeight,
(rect.getX() - padW) / imageWidth,
(rect.getY() - padH) / imageHeight,
rect.getWidth() / imageWidth,
rect.getHeight() / imageHeight));
} else {
retBB.add(rect);
rect.getHeight() / imageHeight);
} else if (applyRatio) {
rect =
new Rectangle(
rect.getX() / width,
rect.getY() / height,
rect.getWidth() / width,
rect.getHeight() / height);
}
retBB.add(rect);
}
}
return new DetectedObjects(retClasses, retProbs, retBB);
}

protected DetectedObjects processFromBoxOutput(NDList list) {
protected DetectedObjects processFromBoxOutput(int imageWidth, int imageHeight, NDList list) {
float[] flattened = list.get(0).toFloatArray();
int sizeClasses = classes.size();
int stride = 5 + sizeClasses;
Expand Down Expand Up @@ -142,7 +154,7 @@ protected DetectedObjects processFromBoxOutput(NDList list) {
classIds.add(maxIndex);
}
}
return nms(boxes, classIds, scores);
return nms(imageWidth, imageHeight, boxes, classIds, scores);
}

private DetectedObjects processFromDetectOutput() {
Expand All @@ -153,18 +165,20 @@ private DetectedObjects processFromDetectOutput() {
/** {@inheritDoc} */
@Override
public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
int imageWidth = (Integer) ctx.getAttachment("width");
int imageHeight = (Integer) ctx.getAttachment("height");
switch (yoloOutputLayerType) {
case DETECT:
return processFromDetectOutput();
case AUTO:
if (list.get(0).getShape().dimension() > 2) {
return processFromDetectOutput();
} else {
return processFromBoxOutput(list);
return processFromBoxOutput(imageWidth, imageHeight, list);
}
case BOX:
default:
return processFromBoxOutput(list);
return processFromBoxOutput(imageWidth, imageHeight, list);
}
}

Expand Down
Loading

0 comments on commit b87cc7f

Please sign in to comment.