Skip to content

Commit

Permalink
support both CATEGORY_MASK and CONFIDENCE_MASK output type
Browse files Browse the repository at this point in the history
  • Loading branch information
haruiz committed Mar 26, 2023
1 parent 542f8c0 commit 77a927c
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 43 deletions.
2 changes: 1 addition & 1 deletion codelabs/background_segmenter/code/camera.js
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ export default class Camera {
throw new Error("Error starting the camera: " + e.message);
}
}
async setResolution(width, height){
setResolution(width, height){
/**
* Sets the camera resolution
* @param width {number}
Expand Down
139 changes: 98 additions & 41 deletions codelabs/background_segmenter/code/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import vision from "https://cdn.skypack.dev/@mediapipe/tasks-vision";

// the Camera class will help us to interact with the webcam
import Camera from "./camera.js";
import {downloadImage, scaleImageData} from "./utils.js";
import {downloadImage, resizeImageData} from "./utils.js";
const { ImageSegmenter, SegmentationMask, FilesetResolver } = vision;

// get a reference to the DOM elements that we wil use
Expand All @@ -17,6 +17,8 @@ const txtBackgroundImageInput = document.getElementById('txtBackgroundImage');
// get a reference to the canvas element and its context
const videCanvas = document.getElementById('canvas');
const videoCanvasCtx = videCanvas.getContext('2d');
const segmenterOutputType = "CATEGORY_MASK"
// const segmenterOutputType = "CONFIDENCE_MASK"

// create a new camera instance
const camera = new Camera(video, video.videoWidth, video.videoHeight);
Expand All @@ -27,27 +29,28 @@ let requestAnimationFrameId = null;
let imageSegmenter = null;
let labelsToSegment = ["dog", "person"]
const labels = [
'background',
'aeroplane',
'bicycle',
'bird',
'boat',
'bottle',
'bus',
'car',
'cat',
'chair',
'cow',
'dining table',
'dog',
'horse',
'motorbike',
'person',
'potted plant',
'sheep',
'sofa',
'train',
'tv'];
"background",
"aeroplane",
"bicycle",
"bird",
"boat",
"bottle",
"bus",
"car",
"cat",
"chair",
"cow",
"dining table",
"dog",
"horse",
"motorbike",
"person",
"potted plant",
"sheep",
"sofa",
"train",
"tv"
]

// this function will be called when the webpage is loaded
document.addEventListener('DOMContentLoaded', async () =>{
Expand Down Expand Up @@ -75,12 +78,11 @@ async function createImageSegmenter() {
"https://storage.googleapis.com/mediapipe-assets/deeplabv3.tflite?generation=1661875711618421"

},
outputType: segmenterOutputType,
runningMode: "VIDEO",
})
}



// draw the segmentation mask on the canvas
function segmentationCallback(segmentationMask){
/**
Expand Down Expand Up @@ -136,7 +138,7 @@ async function startCamera(){
const deviceId = selVideoSource.value;
const resolution = selVideoResolution.value;
const [width, height] = resolution.split('x');
await camera.setResolution(width, height);
camera.setResolution(width, height);
await camera.start(deviceId);
}

Expand Down Expand Up @@ -206,10 +208,14 @@ btnStop.addEventListener('click', async () => {
clearCanvas();
});

function createSegmentationMaskFromLabels(SegmentationMaskLabels) {
function createSegmentationMaskFromCategoryMask(categoryMask, targetWidth= null, targetHeight = null){
/**
* Creates a segmentation mask from the segmentation mask labels
* @type {Uint8ClampedArray}
* @param SegmentationMaskLabels {Array} each value in this array corresponds
* to the label associated with the pixel at each index in the segmentation mask
* @param targetWidth {number} the target width of the segmentation mask
* @param targetHeight {number} the target height of the segmentation mask
* @returns {ImageData} the segmentation mask
*/

// create Uint8ClampedArray to hold the segmentation mask data
Expand All @@ -218,11 +224,11 @@ function createSegmentationMaskFromLabels(SegmentationMaskLabels) {
);

// we loop through the segmentation mask labels and create the segmentation mask
for (let i in SegmentationMaskLabels) {
for (let i in categoryMask) {

// for each pixel we get the label index from the segmentation mask labels and
// we get the label from the labels array
const labelIdx = SegmentationMaskLabels[i];
const labelIdx = categoryMask[i];
const label = labels[labelIdx];

// we check if the label is in the labelsToSegment array
Expand All @@ -241,13 +247,61 @@ function createSegmentationMaskFromLabels(SegmentationMaskLabels) {
segmentationMask[i * 4 + 3] = 255;
}
}
return segmentationMask;
// we create an ImageData object from the segmentation mask
let segmentationMaskImageData = new ImageData(segmentationMask, video.videoWidth, video.videoHeight);
if(targetWidth && targetHeight){
// if the target width and height are defined, we resize the segmentation mask
segmentationMaskImageData = resizeImageData(segmentationMaskImageData, targetWidth, targetHeight);
}
return segmentationMaskImageData;
}

function createSegmentationMaskFromConfidenceMask(SegmentationMaskLabels, targetWidth= null, targetHeight = null, threshold = 20.0){
/**
* Creates a segmentation mask from the segmentation mask labels
* @param SegmentationMaskLabels {Array} each value in this array corresponds
* to the label associated with the pixel at each index in the segmentation mask
* @param targetWidth {number} the target width of the segmentation mask
* @param targetHeight {number} the target height of the segmentation mask
* @returns {ImageData} the segmentation mask
*/

// create Uint8ClampedArray to hold the segmentation mask data
let segmentationMask = new Uint8ClampedArray(
video.videoWidth * video.videoHeight * 4
);
labelsToSegment.forEach((label, idx) => {
const labelMaskIdx = labels.indexOf(label);
const labelMask = SegmentationMaskLabels[labelMaskIdx];
for (let i in labelMask) {
if (labelMask[i] > threshold) {
segmentationMask[i * 4 + 0] = 255;
segmentationMask[i * 4 + 1] = 255;
segmentationMask[i * 4 + 2] = 255;
segmentationMask[i * 4 + 3] = 255;
} else {
segmentationMask[i * 4 + 0] = 0;
segmentationMask[i * 4 + 1] = 0;
segmentationMask[i * 4 + 2] = 0;
segmentationMask[i * 4 + 3] = 255;
}
}
});
// we create an ImageData object from the segmentation mask
let segmentationMaskImageData = new ImageData(segmentationMask, video.videoWidth, video.videoHeight);
if(targetWidth && targetHeight){
// if the target width and height are defined, we resize the segmentation mask
segmentationMaskImageData = resizeImageData(segmentationMaskImageData, targetWidth, targetHeight);
}
return segmentationMaskImageData;
}

async function drawSegmentationMask({ 0 : SegmentationMaskLabels} = SegmentationResult){


async function drawSegmentationMask(SegmentationResult){

/**
* Draws the segmentation mask on the canvas
* Draws the segmentation results on the canvas
* @param SegmentationMaskLabels {Array} the segmentation mask labels
*/

Expand All @@ -257,12 +311,12 @@ async function drawSegmentationMask({ 0 : SegmentationMaskLabels} = Segmentation
const videoWidth = video.videoWidth;
const videoHeight = video.videoHeight;

// calculate the scale
// calculate the scale of the video to fit the canvas
const scaleX = canvasWidth / videoWidth;
const scaleY = canvasHeight / videoHeight;
const scale = Math.min(scaleX, scaleY);

// Scale the video to fit the canvas.
// The scale is defined for the video width and height
const scaledWidth = videoWidth * scale;
const scaledHeight = videoHeight * scale;

Expand All @@ -272,13 +326,15 @@ async function drawSegmentationMask({ 0 : SegmentationMaskLabels} = Segmentation


// create the segmentation mask from the segmentation mask labels
let segmentationMask = createSegmentationMaskFromLabels(SegmentationMaskLabels);
// create an ImageData object from the segmentation mask
let segmentationMaskImageData = new ImageData(segmentationMask, video.videoWidth, video.videoHeight);
// scale the segmentation mask to the video canvas size
segmentationMaskImageData = scaleImageData(segmentationMaskImageData, scaledWidth, scaledHeight);
let segmentationMask = null;
if(segmenterOutputType === "CATEGORY_MASK"){
segmentationMask = createSegmentationMaskFromCategoryMask(SegmentationResult[0], scaledWidth, scaledHeight);
}
else {
segmentationMask = createSegmentationMaskFromConfidenceMask(SegmentationResult, scaledWidth, scaledHeight);
}
// create an ImageBitmap from the scaled segmentation mask
const segmentationMaskBitmap = await createImageBitmap(segmentationMaskImageData);
const segmentationMaskBitmap = await createImageBitmap(segmentationMask);
// create a canvas to hold the scaled segmentation mask and mirror it
const canvasMask = document.createElement('canvas');
const canvasMaskCtx = canvasMask.getContext('2d');
Expand Down Expand Up @@ -321,7 +377,7 @@ async function drawSegmentationMask({ 0 : SegmentationMaskLabels} = Segmentation
}
else {
// we get the background image data
const backgroundData = scaleImageData(backgroundImage, scaledWidth, scaledHeight).data; // background image data
const backgroundData = resizeImageData(backgroundImage, scaledWidth, scaledHeight).data; // background image data
for (let i = 0; i < canvasData.length; i += 4) {
// we check if the pixel is a background pixel or not and we set the pixel to the background image pixel if it is a background pixel
const isBackgroundPixel = binaryMaskData[i] === 0 && binaryMaskData[i + 1] === 0 && binaryMaskData[i + 2] === 0;
Expand All @@ -340,3 +396,4 @@ async function drawSegmentationMask({ 0 : SegmentationMaskLabels} = Segmentation




2 changes: 1 addition & 1 deletion codelabs/background_segmenter/code/utils.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export function scaleImageData(imageData, newWidth, newHeight) {
export function resizeImageData(imageData, newWidth, newHeight) {
/**
* Scales the given image data to the given width and height
* @param imageData {ImageData} The image data to scale
Expand Down
Empty file added codelabs/rigged_hand/utils.js
Empty file.

0 comments on commit 77a927c

Please sign in to comment.