omniparser-node/captioning.js (new file)
import {
  Florence2ForConditionalGeneration,
  AutoProcessor,
  AutoTokenizer,
} from "@huggingface/transformers";

export class Caption {
  /**
   * Create a new Caption model.
   * @param {import('@huggingface/transformers').PreTrainedModel} model The model to use for captioning
   * @param {import('@huggingface/transformers').Processor} processor The processor to use for captioning
   * @param {import('@huggingface/transformers').PreTrainedTokenizer} tokenizer The tokenizer to use for captioning
   */
  constructor(model, processor, tokenizer) {
    this.model = model;
    this.processor = processor;
    this.tokenizer = tokenizer;

    // Prepare text inputs
    this.task = "<CAPTION>";
    const prompts = processor.construct_prompts(this.task);
    this.text_inputs = tokenizer(prompts);
  }

  /**
   * Generate a caption for an image.
   * @param {import('@huggingface/transformers').RawImage} image The input image.
   * @returns {Promise<string>} The caption for the image
   */
  async describe(image) {
    const vision_inputs = await this.processor(image);

    // Generate text
    const generated_ids = await this.model.generate({
      ...this.text_inputs,
      ...vision_inputs,
      max_new_tokens: 256,
    });

    // Decode generated text
    const generated_text = this.tokenizer.batch_decode(generated_ids, {
      skip_special_tokens: false,
    })[0];

    // Post-process the generated text
    const result = this.processor.post_process_generation(
      generated_text,
      this.task,
      image.size,
    );
    return result[this.task];
  }

  static async from_pretrained(model_id) {
    const model = await Florence2ForConditionalGeneration.from_pretrained(
      model_id,
      { dtype: "fp32" },
    );
    const processor = await AutoProcessor.from_pretrained(model_id);
    const tokenizer = await AutoTokenizer.from_pretrained(model_id);

    return new Caption(model, processor, tokenizer);
  }
}
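For reference, a minimal usage sketch of the class above, using the same checkpoint and test image as index.js below (any Florence-2 ONNX checkpoint should work the same way):

import { RawImage } from "@huggingface/transformers";
import { Caption } from "./captioning.js";

const captioning = await Caption.from_pretrained("onnx-community/Florence-2-base-ft");
const image = await RawImage.read(
  "https://raw.githubusercontent.com/microsoft/OmniParser/refs/heads/master/imgs/google_page.png",
);
console.log(await captioning.describe(image)); // a one-line caption of the image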
omniparser-node/detector.js (new file)
import { AutoModel, AutoProcessor, RawImage } from "@huggingface/transformers";

/**
 * @typedef {Object} Detection
 * @property {number} x1 The x-coordinate of the top-left corner.
 * @property {number} y1 The y-coordinate of the top-left corner.
 * @property {number} x2 The x-coordinate of the bottom-right corner.
 * @property {number} y2 The y-coordinate of the bottom-right corner.
 * @property {number} score The confidence score of the detection.
 */

/**
 * Compute Intersection over Union (IoU) between two detections.
 * @param {Detection} a The first detection.
 * @param {Detection} b The second detection.
 */
function iou(a, b) {
  const x1 = Math.max(a.x1, b.x1);
  const y1 = Math.max(a.y1, b.y1);
  const x2 = Math.min(a.x2, b.x2);
  const y2 = Math.min(a.y2, b.y2);

  const intersection = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
  const area1 = (a.x2 - a.x1) * (a.y2 - a.y1);
  const area2 = (b.x2 - b.x1) * (b.y2 - b.y1);
  const union = area1 + area2 - intersection;

  return intersection / union;
}
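A quick worked example of the IoU computation (illustrative values; note that iou is module-private, so this would run inside detector.js itself):

// Two 10×10 boxes overlapping in a 5×10 strip:
// intersection = 50, union = 100 + 100 - 50 = 150, IoU = 50/150 ≈ 0.333.
const a = { x1: 0, y1: 0, x2: 10, y2: 10, score: 0.9 };
const b = { x1: 5, y1: 0, x2: 15, y2: 10, score: 0.8 };
console.log(iou(a, b)); // ≈ 0.333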

/**
* Run Non-Maximum Suppression (NMS) on a list of detections.
* @param {Detection[]} detections The list of detections.
* @param {number} iouThreshold The IoU threshold for NMS.
*/
export function nms(detections, iouThreshold) {
const result = [];
while (detections.length > 0) {
const best = detections.reduce((acc, detection) =>
detection.score > acc.score ? detection : acc,
);
result.push(best);
detections = detections.filter(
(detection) => iou(detection, best) < iouThreshold,
);
}
return result;
}
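Since nms is exported, it can be exercised on its own. A small sketch (values chosen for illustration):

import { nms } from "./detector.js";

const boxes = [
  { x1: 0, y1: 0, x2: 10, y2: 10, score: 0.9 },
  { x1: 1, y1: 0, x2: 11, y2: 10, score: 0.8 }, // IoU ≈ 0.82 with the first box
  { x1: 50, y1: 50, x2: 60, y2: 60, score: 0.7 },
];
// With a 0.7 IoU threshold, the second box is suppressed by the first;
// the result keeps the boxes with scores 0.9 and 0.7, highest score first.
console.log(nms(boxes, 0.7));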

export class Detector {
/**
* Create a new YOLOv8 detector.
* @param {import('@huggingface/transformers').PreTrainedModel} model The model to use for detection
* @param {import('@huggingface/transformers').Processor} processor The processor to use for detection
*/
constructor(model, processor) {
this.model = model;
this.processor = processor;
}

/**
* Run detection on an image.
* @param {RawImage|string|URL} input The input image.
* @param {Object} [options] The options for detection.
* @param {number} [options.confidence_threshold=0.25] The confidence threshold.
* @param {number} [options.iou_threshold=0.7] The IoU threshold for NMS.
* @returns {Promise<Detection[]>} The list of detections
*/
async predict(
input,
{ confidence_threshold = 0.25, iou_threshold = 0.7 } = {},
) {
const image = await RawImage.read(input);
const { pixel_values } = await this.processor(image);

// Run detection
const { output0 } = await this.model({ images: pixel_values });

// Post-process output
const permuted = output0[0].transpose(1, 0);
// `permuted` is a Tensor of shape [ 5460, 5 ]:
// - 5460 potential bounding boxes
// - 5 parameters for each box:
// - first 4 are coordinates for the bounding boxes (x-center, y-center, width, height)
// - the last one is the confidence score

// Format output
const result = [];
const [scaledHeight, scaledWidth] = pixel_values.dims.slice(-2);
for (const [xc, yc, w, h, score] of permuted.tolist()) {
// Filter if not confident enough
if (score < confidence_threshold) continue;

// Get pixel values, taking into account the original image size
const x1 = ((xc - w / 2) / scaledWidth) * image.width;
const y1 = ((yc - h / 2) / scaledHeight) * image.height;
const x2 = ((xc + w / 2) / scaledWidth) * image.width;
const y2 = ((yc + h / 2) / scaledHeight) * image.height;

// Add to result
result.push({ x1, x2, y1, y2, score });
}

return nms(result, iou_threshold);
}

static async from_pretrained(model_id) {
const model = await AutoModel.from_pretrained(model_id, { dtype: "fp32" });
const processor = await AutoProcessor.from_pretrained(model_id);
return new Detector(model, processor);
}
}
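The coordinate conversion in predict maps model-space (x-center, y-center, width, height) boxes back onto the original image. A worked example with made-up numbers:

// Illustrative numbers only: a box predicted at the center of a 640×640
// model input maps back to the center of a 1920×1080 source image.
const [scaledWidth, scaledHeight] = [640, 640];
const [xc, yc, w, h] = [320, 320, 64, 32];
const image = { width: 1920, height: 1080 }; // stand-in for a RawImage
const x1 = ((xc - w / 2) / scaledWidth) * image.width;   // (288/640)*1920 = 864
const y1 = ((yc - h / 2) / scaledHeight) * image.height; // (304/640)*1080 = 513
const x2 = ((xc + w / 2) / scaledWidth) * image.width;   // (352/640)*1920 = 1056
const y2 = ((yc + h / 2) / scaledHeight) * image.height; // (336/640)*1080 = 567
console.log({ x1, y1, x2, y2 });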
omniparser-node/index.js (new file)
import { RawImage } from "@huggingface/transformers";
import { Detector } from "./detector.js";
import { Caption } from "./captioning.js";

// Load detection model
const detector_model_id = "onnx-community/OmniParser-icon_detect";
const detector = await Detector.from_pretrained(detector_model_id);

// Load captioning model
const captioning_model_id = "onnx-community/Florence-2-base-ft";
const captioning = await Caption.from_pretrained(captioning_model_id);

// Read image from URL
const url =
  "https://raw.githubusercontent.com/microsoft/OmniParser/refs/heads/master/imgs/google_page.png";
const image = await RawImage.read(url);

// Run detection
const detections = await detector.predict(image, {
  confidence_threshold: 0.05,
  iou_threshold: 0.7,
});

for (const { x1, x2, y1, y2, score } of detections) {
  // Crop image
  const bbox = [x1, y1, x2, y2].map(Math.round);
  const cropped_image = await image.crop(bbox);

  // Run captioning
  const text = await captioning.describe(cropped_image);
  console.log({ text, bbox, score });
}
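To eyeball what each caption refers to, the crops can also be written to disk. A sketch assuming RawImage exposes a save(path) helper in the Node build of Transformers.js (check your version); the filenames are illustrative:

// Optional: dump each crop for visual inspection.
let i = 0;
for (const { x1, x2, y1, y2 } of detections) {
  const bbox = [x1, y1, x2, y2].map(Math.round);
  const cropped_image = await image.crop(bbox);
  await cropped_image.save(`crop_${i++}.png`); // assumes RawImage#save (Node)
}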