Skip to content

Commit

Permalink
[examples] Adds segment anything 2 example (#3449)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Sep 3, 2024
1 parent 50e2770 commit 555c596
Show file tree
Hide file tree
Showing 18 changed files with 606 additions and 34 deletions.
51 changes: 34 additions & 17 deletions api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ public List<BoundingBox> findBoundingBoxes() {

/** {@inheritDoc} */
@Override
public void drawBoundingBoxes(DetectedObjects detections) {
public void drawBoundingBoxes(DetectedObjects detections, float opacity) {
// Make image copy with alpha channel because original image was jpg
convertIdNeeded();

Expand All @@ -321,25 +321,40 @@ public void drawBoundingBoxes(DetectedObjects detections) {
k = (k + 100) % 255;
}

Rectangle rectangle = box.getBounds();
int x = (int) (rectangle.getX() * imageWidth);
int y = (int) (rectangle.getY() * imageHeight);
g.drawRect(
x,
y,
(int) (rectangle.getWidth() * imageWidth),
(int) (rectangle.getHeight() * imageHeight));
drawText(g, className, x, y, stroke, 4);
if (!className.isEmpty()) {
Rectangle rectangle = box.getBounds();
int x = (int) (rectangle.getX() * imageWidth);
int y = (int) (rectangle.getY() * imageHeight);
g.drawRect(
x,
y,
(int) (rectangle.getWidth() * imageWidth),
(int) (rectangle.getHeight() * imageHeight));
drawText(g, className, x, y, stroke, 4);
}
// If we have a mask instead of a plain rectangle, draw the mask
if (box instanceof Mask) {
drawMask((Mask) box);
drawMask((Mask) box, opacity);
} else if (box instanceof Landmark) {
drawLandmarks(box);
}
}
g.dispose();
}

/** {@inheritDoc} */
@Override
public void drawMarks(List<Point> points, int radius) {
    Graphics2D g = (Graphics2D) image.getGraphics();
    try {
        g.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
        g.setColor(new Color(246, 96, 0));
        for (Point point : points) {
            // createStar returns {xCoordinates, yCoordinates} of the star outline
            int[][] star = createStar(point, radius);
            // Take the vertex count from the returned array instead of hard-coding it,
            // so the call stays correct if createStar is overridden.
            g.fillPolygon(star[0], star[1], star[0].length);
        }
    } finally {
        // Release the graphics context even if drawing one of the marks fails.
        g.dispose();
    }
}

/** {@inheritDoc} */
@Override
public void drawJoints(Joints joints) {
Expand Down Expand Up @@ -421,7 +436,7 @@ private void drawText(Graphics2D g, String text, int x, int y, int stroke, int p
g.drawString(text, x + padding, y + ascent);
}

private void drawMask(Mask mask) {
private void drawMask(Mask mask, float ratio) {
float r = RandomUtils.nextFloat();
float g = RandomUtils.nextFloat();
float b = RandomUtils.nextFloat();
Expand All @@ -445,13 +460,15 @@ private void drawMask(Mask mask) {
}
}
float[][] probDist = mask.getProbDist();
float max = 0;
for (float[] row : probDist) {
for (float f : row) {
max = Math.max(max, f);
if (ratio < 0 || ratio > 1) {
float max = 0;
for (float[] row : probDist) {
for (float f : row) {
max = Math.max(max, f);
}
}
ratio = 0.5f / max;
}
float ratio = 0.5f / max;

BufferedImage maskImage =
new BufferedImage(
Expand Down
58 changes: 57 additions & 1 deletion api/src/main/java/ai/djl/modality/cv/Image.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Joints;
import ai.djl.modality.cv.output.Point;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;

Expand Down Expand Up @@ -125,7 +126,36 @@ default NDArray toNDArray(NDManager manager) {
*
* @param detections the object detection results
*/
void drawBoundingBoxes(DetectedObjects detections);
default void drawBoundingBoxes(DetectedObjects detections) {
    // -1 is not a usable opacity; presumably it leaves the choice of mask
    // opacity to the implementation -- confirm against implementers.
    final float unspecifiedOpacity = -1;
    drawBoundingBoxes(detections, unspecifiedOpacity);
}

/**
 * Draws the bounding boxes on the image.
 *
 * @param detections the object detection results
 * @param opacity the opacity forwarded to mask drawing. NOTE(review): the BufferedImage
 *     implementation replaces a value outside [0, 1] with one derived from the mask's
 *     maximum probability -- confirm that other implementations follow the same rule
 */
void drawBoundingBoxes(DetectedObjects detections, float opacity);

/**
 * Draws marks on the image using a size derived from the image dimensions.
 *
 * <p>The size is the shorter image side divided by 50 (integer division).
 *
 * @param points a list of {@code Point}
 */
default void drawMarks(List<Point> points) {
    int shorterSide = Math.min(getWidth(), getHeight());
    drawMarks(points, shorterSide / 50);
}

/**
 * Draws a mark on the image for the given points.
 *
 * @param points a list of {@code Point}
 * @param size the radius of the star mark
 */
void drawMarks(List<Point> points, int size);

/**
* Draws all joints of a body on an image.
Expand All @@ -142,6 +172,32 @@ default NDArray toNDArray(NDManager manager) {
*/
void drawImage(Image overlay, boolean resize);

/**
 * Builds the outline of a star shape centered on a point.
 *
 * <p>Ten vertices are produced at steps of {@code PI / 5} starting from angle 0. Even-indexed
 * vertices lie at {@code radius}; odd-indexed vertices lie at {@code radius * 0.38196601125}.
 * Each coordinate is cast to {@code int} after the center offset is added.
 *
 * @param point the center coordinate
 * @param radius the outer radius
 * @return the polygon points: x coordinates at index 0 and y coordinates at index 1
 */
default int[][] createStar(Point point, int radius) {
    double centerX = point.getX();
    double centerY = point.getY();
    double outer = radius;
    double inner = radius * 0.38196601125;
    double step = Math.PI / 5;

    int[] xs = new int[10];
    int[] ys = new int[10];
    for (int i = 0; i < xs.length; ++i) {
        double angle = i * step;
        // Alternate between the outer tip and the inner notch of the star.
        double r = (i % 2 == 0) ? outer : inner;
        xs[i] = (int) (Math.cos(angle) * r + centerX);
        ys[i] = (int) (Math.sin(angle) * r + centerY);
    }
    return new int[][] {xs, ys};
}

/** Flag indicates the color channel options for images. */
enum Flag {
GRAYSCALE,
Expand Down
21 changes: 21 additions & 0 deletions api/src/main/java/ai/djl/modality/cv/output/Mask.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
*/
package ai.djl.modality.cv.output;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.Shape;

/**
* A mask with a probability for each pixel within a bounding rectangle.
*
Expand Down Expand Up @@ -75,4 +78,22 @@ public float[][] getProbDist() {
public boolean isFullImageMask() {
return fullImageMask;
}

/**
 * Copies a mask tensor into a {@code float[height][width]} array.
 *
 * <p>Dimension 0 of the array's shape is read as the height and dimension 1 as the width; the
 * flattened float data is then split into consecutive rows of {@code width} elements.
 *
 * @param array the mask NDArray
 * @return the mask array
 */
public static float[][] toMask(NDArray array) {
    Shape shape = array.getShape();
    int rows = (int) shape.get(0);
    int cols = (int) shape.get(1);
    float[] flat = array.toFloatArray();
    float[][] result = new float[rows][cols];
    int offset = 0;
    for (float[] row : result) {
        System.arraycopy(flat, offset, row, 0, cols);
        offset += cols;
    }
    return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,7 @@ public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {

// Reshape mask to actual image bounding box shape.
NDArray array = masks.get(i);
Shape maskShape = array.getShape();
int maskH = (int) maskShape.get(0);
int maskW = (int) maskShape.get(1);
float[] flattened = array.toFloatArray();
float[][] maskFloat = new float[maskH][maskW];
for (int j = 0; j < maskH; j++) {
System.arraycopy(flattened, j * maskW, maskFloat[j], 0, maskW);
}
float[][] maskFloat = Mask.toMask(array);
Mask mask = new Mask(x, y, w, h, maskFloat);

retNames.add(className);
Expand Down
Loading

0 comments on commit 555c596

Please sign in to comment.