diff --git a/image/yolo.ipynb b/image/yolo.ipynb
new file mode 100644
index 0000000..b9b99cb
--- /dev/null
+++ b/image/yolo.ipynb
@@ -0,0 +1,445 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "53dc444c",
+ "metadata": {},
+ "source": [
+ "## Run with YOLO\n",
+ "\n",
+ "- yolov5: https://pytorch.org/hub/ultralytics_yolov5/\n",
+ "- yolov8: https://github.com/ultralytics/ultralytics, https://docs.ultralytics.com/modes/predict/"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "4b5b2ce6",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "YOLOv5 🚀 2022-12-28 Python-3.8.12 torch-1.13.1 CPU\n",
+ "\n",
+ "Fusing layers... \n",
+ "YOLOv5s summary: 213 layers, 7225885 parameters, 0 gradients\n",
+ "Adding AutoShape... \n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "model1 = torch.hub.load(\"ultralytics/yolov5\", \"yolov5s\", pretrained=True, verbose=False)\n",
+ "\n",
+ "\n",
+ "from ultralytics import YOLO\n",
+ "model2 = YOLO(\"yolov8n.pt\") "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "5e163003",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Ultralytics YOLOv8.0.48 🚀 Python-3.8.12 torch-1.13.1 CPU\n",
+ "YOLOv8n summary (fused): 168 layers, 3151904 parameters, 0 gradients, 8.7 GFLOPs\n",
+ "\n",
+ "Found https://ultralytics.com/images/bus.jpg locally at bus.jpg\n",
+ "image 1/1 /Users/chenshiyu/workspace/git/towhee/jupyter/bus.jpg: 640x480 4 persons, 1 bus, 1 stop sign, 67.5ms\n",
+ "Speed: 1.5ms preprocess, 67.5ms inference, 1.0ms postprocess per image at shape (1, 3, 640, 640)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[17, 231, 802, 768], [49, 399, 245, 903], [670, 380, 810, 876], [221, 406, 345, 857], [0, 255, 32, 325], [0, 551, 67, 874]] ['bus', 'person', 'person', 'person', 'stop sign', 'person'] [0.8705446124076843, 0.8689801692962646, 0.8536036610603333, 0.8193051218986511, 0.34606924653053284, 0.301294207572937]\n"
+ ]
+ }
+ ],
+ "source": [
+ "results = model2(\"https://ultralytics.com/images/bus.jpg\")\n",
+ "\n",
+ "result = results[0]\n",
+ "boxes = [list(map(int, xyxy)) for xyxy in result.boxes.xyxy]\n",
+ "classes = [result.names[int(i)] for i in result.boxes.cls]\n",
+ "scores = result.boxes.conf.tolist()\n",
+ "print(boxes, classes, scores)"
+ ]
+ },
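+ {
+ "cell_type": "markdown",
+ "id": "a1f3d9e2",
+ "metadata": {},
+ "source": [
+ "As a quick sanity check, the parsed `boxes`, `classes` and `scores` can be drawn back onto the image. A minimal sketch with OpenCV, assuming `bus.jpg` was cached locally by the predict call above (`bus_pred.jpg` is just an illustrative output name):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b7c21a04",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import cv2\n",
+ "\n",
+ "# Draw the detections parsed above onto the cached copy of the test image.\n",
+ "img = cv2.imread(\"bus.jpg\")\n",
+ "for (x1, y1, x2, y2), name, score in zip(boxes, classes, scores):\n",
+ "    cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)\n",
+ "    cv2.putText(img, f\"{name} {score:.2f}\", (x1, max(y1 - 5, 0)),\n",
+ "                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)\n",
+ "cv2.imwrite(\"bus_pred.jpg\", img)"
+ ]
+ },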
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "2ddf7d2c",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[671, 395, 810, 878], [220, 408, 346, 867], [49, 389, 247, 912], [12, 223, 809, 789], [0, 552, 67, 875]] ['person', 'person', 'person', 'bus', 'person'] [0.8966755867004395, 0.8693942427635193, 0.850602388381958, 0.8504517078399658, 0.5373987555503845]\n"
+ ]
+ }
+ ],
+ "source": [
+ "results = model1(\"https://ultralytics.com/images/bus.jpg\")\n",
+ "boxes = [re[0:4] for re in results.xyxy[0]]\n",
+ "boxes = [list(map(int, box)) for box in boxes]\n",
+ "classes = list(results.pandas().xyxy[0].name)\n",
+ "scores = list(results.pandas().xyxy[0].confidence)\n",
+ "print(boxes, classes, scores)"
+ ]
+ },
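+ {
+ "cell_type": "markdown",
+ "id": "c4e8f1b7",
+ "metadata": {},
+ "source": [
+ "yolov5's `Detections` object also ships its own visualisation helpers, so the annotated image can be written out directly (yolov5 picks the output directory, typically `runs/detect/exp*`):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d92a6c13",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Let yolov5 render and save the annotated image itself.\n",
+ "results.save()"
+ ]
+ },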
+ {
+ "cell_type": "markdown",
+ "id": "3c6874eb",
+ "metadata": {},
+ "source": [
+ "## Develop yolo operator\n",
+ "\n",
+ "- yolov5 op: https://towhee.io/object-detection/yolov5\n",
+ "- yolo op which can specific the model, suchas the model(https://github.com/ultralytics/assets/releases)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "5ac79368",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using cache found in /Users/chenshiyu/.cache/torch/hub/ultralytics_yolov5_master\n"
+ ]
+ }
+ ],
+ "source": [
+ "from towhee import pipe, ops, DataCollection\n",
+ "\n",
+ "p = (\n",
+ " pipe.input('path')\n",
+ " .map('path', 'img', ops.image_decode.cv2_rgb())\n",
+ " .map('img', ('box', 'class', 'score'), ops.object_detection.yolov5())\n",
+ " .map(('img', 'box'), 'object', ops.image_crop(clamp=True))\n",
+ " .output('img', 'object', 'class')\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "c86c59d6",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
img | object | class |
\n",
+ " | | person person person bus person |
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "DataCollection(p(\"https://ultralytics.com/images/bus.jpg\")).show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "a336ab8a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from towhee import pipe, ops, DataCollection\n",
+ "\n",
+ "p = (\n",
+ " pipe.input('path')\n",
+ " .map('path', 'img', ops.image_decode.cv2_rgb())\n",
+ " .map('img', ('box', 'class', 'score'), ops.object_detection.yolo(model=\"yolov8n.pt\"))\n",
+ "# .map('img', ('box', 'class', 'score'), ops.object_detection.yolov5())\n",
+ " .map(('img', 'box'), 'object', ops.image_crop(clamp=True))\n",
+ " .output('img', 'object', 'class')\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "442ebf2d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Ultralytics YOLOv8.0.48 🚀 Python-3.8.12 torch-1.13.1 CPU\n",
+ "YOLOv8n summary (fused): 168 layers, 3151904 parameters, 0 gradients, 8.7 GFLOPs\n",
+ "\n",
+ "0: 640x480 4 persons, 1 bus, 65.2ms\n",
+ "Speed: 1.1ms preprocess, 65.2ms inference, 0.6ms postprocess per image at shape (1, 3, 640, 640)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "img | object | class |
\n",
+ " | | person person bus person person |
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "DataCollection(p(\"https://ultralytics.com/images/bus.jpg\")).show()"
+ ]
+ },
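+ {
+ "cell_type": "markdown",
+ "id": "e57b3da9",
+ "metadata": {},
+ "source": [
+ "The `object_detection.yolo` op above is a thin wrapper around the ultralytics API. A minimal sketch of what such an operator could look like, written here as a plain Python class (`YoloDetector` is hypothetical, not the actual towhee hub implementation):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f6a04b28",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from typing import List, Tuple\n",
+ "\n",
+ "import numpy\n",
+ "from ultralytics import YOLO\n",
+ "\n",
+ "\n",
+ "class YoloDetector:\n",
+ "    \"\"\"Hypothetical sketch of a yolov8-style detection operator.\"\"\"\n",
+ "\n",
+ "    def __init__(self, model: str = \"yolov8n.pt\"):\n",
+ "        # Any released weights file (e.g. yolov8s.pt) could be passed here.\n",
+ "        self._model = YOLO(model)\n",
+ "\n",
+ "    def __call__(self, img: numpy.ndarray) -> Tuple[List[List[int]], List[str], List[float]]:\n",
+ "        result = self._model(img)[0]\n",
+ "        boxes = [list(map(int, xyxy)) for xyxy in result.boxes.xyxy]\n",
+ "        classes = [result.names[int(i)] for i in result.boxes.cls]\n",
+ "        scores = result.boxes.conf.tolist()\n",
+ "        return boxes, classes, scores"
+ ]
+ },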
+ {
+ "cell_type": "markdown",
+ "id": "2943fa74",
+ "metadata": {},
+ "source": [
+ "## Set for training\n",
+ "\n",
+ "> `makir dataset` in '..'\n",
+ "\n",
+ "- yolo train: https://github.com/ultralytics/ultralytics/blob/dce4efce48a05e028e6ec430045431c242e52484/docs/yolov5/tutorials/train_custom_data.md"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "9694e1e1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import towhee\n",
+ "\n",
+ "op = towhee.ops.object_detection.yolo().get_op()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "340c963c",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "New https://pypi.org/project/ultralytics/8.0.100 available 😃 Update with 'pip install -U ultralytics'\n",
+ "Ultralytics YOLOv8.0.48 🚀 Python-3.8.12 torch-1.13.1 CPU\n",
+ "\u001b[34m\u001b[1myolo/engine/trainer: \u001b[0mtask=detect, mode=train, model=yolov8n.pt, data=coco128.yaml, epochs=3, patience=50, batch=16, imgsz=640, save=True, save_period=-1, cache=False, device=None, workers=8, project=None, name=None, exist_ok=False, pretrained=False, optimizer=SGD, verbose=True, seed=0, deterministic=True, single_cls=False, image_weights=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, min_memory=False, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=True, source=None, show=False, save_txt=False, save_conf=False, save_crop=False, hide_labels=False, hide_conf=False, vid_stride=1, line_thickness=3, visualize=False, augment=False, agnostic_nms=False, classes=None, retina_masks=False, boxes=True, format=torchscript, keras=False, optimize=False, int8=False, dynamic=False, simplify=False, opset=None, workspace=4, nms=False, lr0=0.01, lrf=0.01, momentum=0.937, weight_decay=0.0005, warmup_epochs=3.0, warmup_momentum=0.8, warmup_bias_lr=0.1, box=7.5, cls=0.5, dfl=1.5, fl_gamma=0.0, label_smoothing=0.0, nbs=64, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, flipud=0.0, fliplr=0.5, mosaic=1.0, mixup=0.0, copy_paste=0.0, cfg=None, v5loader=False, tracker=botsort.yaml, save_dir=runs/detect/train4\n",
+ "\n",
+ " from n params module arguments \n",
+ " 0 -1 1 464 ultralytics.nn.modules.Conv [3, 16, 3, 2] \n",
+ " 1 -1 1 4672 ultralytics.nn.modules.Conv [16, 32, 3, 2] \n",
+ " 2 -1 1 7360 ultralytics.nn.modules.C2f [32, 32, 1, True] \n",
+ " 3 -1 1 18560 ultralytics.nn.modules.Conv [32, 64, 3, 2] \n",
+ " 4 -1 2 49664 ultralytics.nn.modules.C2f [64, 64, 2, True] \n",
+ " 5 -1 1 73984 ultralytics.nn.modules.Conv [64, 128, 3, 2] \n",
+ " 6 -1 2 197632 ultralytics.nn.modules.C2f [128, 128, 2, True] \n",
+ " 7 -1 1 295424 ultralytics.nn.modules.Conv [128, 256, 3, 2] \n",
+ " 8 -1 1 460288 ultralytics.nn.modules.C2f [256, 256, 1, True] \n",
+ " 9 -1 1 164608 ultralytics.nn.modules.SPPF [256, 256, 5] \n",
+ " 10 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n",
+ " 11 [-1, 6] 1 0 ultralytics.nn.modules.Concat [1] \n",
+ " 12 -1 1 148224 ultralytics.nn.modules.C2f [384, 128, 1] \n",
+ " 13 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n",
+ " 14 [-1, 4] 1 0 ultralytics.nn.modules.Concat [1] \n",
+ " 15 -1 1 37248 ultralytics.nn.modules.C2f [192, 64, 1] \n",
+ " 16 -1 1 36992 ultralytics.nn.modules.Conv [64, 64, 3, 2] \n",
+ " 17 [-1, 12] 1 0 ultralytics.nn.modules.Concat [1] \n",
+ " 18 -1 1 123648 ultralytics.nn.modules.C2f [192, 128, 1] \n",
+ " 19 -1 1 147712 ultralytics.nn.modules.Conv [128, 128, 3, 2] \n",
+ " 20 [-1, 9] 1 0 ultralytics.nn.modules.Concat [1] \n",
+ " 21 -1 1 493056 ultralytics.nn.modules.C2f [384, 256, 1] \n",
+ " 22 [15, 18, 21] 1 897664 ultralytics.nn.modules.Detect [80, [64, 128, 256]] \n",
+ "Model summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs\n",
+ "\n",
+ "Transferred 355/355 items from pretrained weights\n",
+ "\u001b[34m\u001b[1moptimizer:\u001b[0m SGD(lr=0.01) with parameter groups 57 weight(decay=0.0), 64 weight(decay=0.0005), 63 bias\n",
+ "\u001b[34m\u001b[1mtrain: \u001b[0mScanning /Users/chenshiyu/workspace/git/towhee/jupyter/datasets/coco128/labels/train2017.cache... 126 images, 2 backgrounds, 0 corrupt: 100%|██████████| 128/128 [00:00, ?it/s]\u001b[0m\n",
+ "\u001b[34m\u001b[1mval: \u001b[0mScanning /Users/chenshiyu/workspace/git/towhee/jupyter/datasets/coco128/labels/train2017.cache... 126 images, 2 backgrounds, 0 corrupt: 100%|██████████| 128/128 [00:00, ?it/s]\u001b[0m\n",
+ "Plotting labels to runs/detect/train4/labels.jpg... \n",
+ "Image sizes 640 train, 640 val\n",
+ "Using 0 dataloader workers\n",
+ "Logging results to \u001b[1mruns/detect/train4\u001b[0m\n",
+ "Starting training for 3 epochs...\n",
+ "\n",
+ " Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size\n",
+ " 1/3 0G 1.166 1.386 1.217 215 640: 100%|██████████| 8/8 [00:49<00:00, 6.20s/it]\n",
+ " Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:16<00:00, 4.15s/it]\n",
+ " all 128 929 0.65 0.549 0.617 0.456\n",
+ "\n",
+ " Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size\n",
+ " 2/3 0G 1.182 1.427 1.255 185 640: 100%|██████████| 8/8 [00:48<00:00, 6.08s/it]\n",
+ " Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:14<00:00, 3.67s/it]\n",
+ " all 128 929 0.677 0.584 0.646 0.48\n",
+ "\n",
+ " Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size\n",
+ " 3/3 0G 1.173 1.319 1.25 246 640: 100%|██████████| 8/8 [00:46<00:00, 5.81s/it]\n",
+ " Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:17<00:00, 4.39s/it]\n",
+ " all 128 929 0.679 0.59 0.657 0.488\n",
+ "\n",
+ "3 epochs completed in 0.055 hours.\n",
+ "Optimizer stripped from runs/detect/train4/weights/last.pt, 6.5MB\n",
+ "Optimizer stripped from runs/detect/train4/weights/best.pt, 6.5MB\n",
+ "\n",
+ "Validating runs/detect/train4/weights/best.pt...\n",
+ "Ultralytics YOLOv8.0.48 🚀 Python-3.8.12 torch-1.13.1 CPU\n",
+ "Model summary (fused): 168 layers, 3151904 parameters, 0 gradients, 8.7 GFLOPs\n",
+ " Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:10<00:00, 2.62s/it]\n",
+ " all 128 929 0.681 0.589 0.657 0.489\n",
+ " person 128 254 0.754 0.669 0.76 0.539\n",
+ " bicycle 128 6 0.775 0.333 0.386 0.264\n",
+ " car 128 46 0.806 0.217 0.33 0.186\n",
+ " motorcycle 128 5 0.687 0.8 0.92 0.752\n",
+ " airplane 128 6 0.741 0.958 0.955 0.75\n",
+ " bus 128 7 0.751 0.714 0.734 0.675\n",
+ " train 128 3 0.679 1 0.913 0.855\n",
+ " truck 128 12 0.967 0.5 0.527 0.354\n",
+ " boat 128 6 0.344 0.167 0.429 0.291\n",
+ " traffic light 128 14 0.661 0.214 0.226 0.142\n",
+ " stop sign 128 2 0.688 1 0.995 0.721\n",
+ " bench 128 9 0.76 0.556 0.683 0.532\n",
+ " bird 128 16 0.88 0.875 0.956 0.618\n",
+ " cat 128 4 0.684 1 0.895 0.697\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " dog 128 9 0.556 0.778 0.803 0.578\n",
+ " horse 128 2 0.688 1 0.995 0.484\n",
+ " elephant 128 17 0.773 0.941 0.942 0.751\n",
+ " bear 128 1 0.537 1 0.995 0.995\n",
+ " zebra 128 4 0.861 1 0.995 0.965\n",
+ " giraffe 128 9 0.781 1 0.973 0.708\n",
+ " backpack 128 6 0.648 0.333 0.467 0.263\n",
+ " umbrella 128 18 0.658 0.535 0.689 0.453\n",
+ " handbag 128 19 1 0.117 0.27 0.149\n",
+ " tie 128 7 0.775 0.714 0.739 0.518\n",
+ " suitcase 128 4 0.773 0.864 0.895 0.569\n",
+ " frisbee 128 5 0.763 0.8 0.759 0.655\n",
+ " skis 128 1 0.748 1 0.995 0.497\n",
+ " snowboard 128 7 0.469 0.714 0.666 0.48\n",
+ " sports ball 128 6 0.71 0.421 0.556 0.276\n",
+ " kite 128 10 0.645 0.547 0.554 0.204\n",
+ " baseball bat 128 4 0.448 0.421 0.252 0.127\n",
+ " baseball glove 128 7 0.692 0.429 0.431 0.303\n",
+ " skateboard 128 5 0.769 0.6 0.6 0.4\n",
+ " tennis racket 128 7 0.731 0.395 0.534 0.331\n",
+ " bottle 128 18 0.516 0.444 0.465 0.274\n",
+ " wine glass 128 16 0.514 0.562 0.597 0.346\n",
+ " cup 128 36 0.661 0.325 0.434 0.317\n",
+ " fork 128 6 0.64 0.167 0.191 0.184\n",
+ " knife 128 16 0.647 0.5 0.61 0.366\n",
+ " spoon 128 22 0.575 0.185 0.347 0.202\n",
+ " bowl 128 28 0.567 0.643 0.654 0.532\n",
+ " banana 128 1 0 0 0.199 0.0647\n",
+ " sandwich 128 2 1 0.778 0.995 0.995\n",
+ " orange 128 4 1 0.393 0.828 0.535\n",
+ " broccoli 128 11 0.48 0.254 0.277 0.23\n",
+ " carrot 128 24 0.692 0.561 0.74 0.482\n",
+ " hot dog 128 2 0.462 1 0.995 0.946\n",
+ " pizza 128 5 0.747 1 0.995 0.834\n",
+ " donut 128 14 0.633 1 0.941 0.857\n",
+ " cake 128 4 0.852 1 0.995 0.89\n",
+ " chair 128 35 0.571 0.543 0.489 0.297\n",
+ " couch 128 6 0.362 0.5 0.599 0.471\n",
+ " potted plant 128 14 0.649 0.643 0.717 0.48\n",
+ " bed 128 3 0.911 1 0.995 0.798\n",
+ " dining table 128 13 0.507 0.615 0.503 0.403\n",
+ " toilet 128 2 1 0.949 0.995 0.946\n",
+ " tv 128 2 0.483 0.5 0.745 0.696\n",
+ " laptop 128 3 1 0 0.5 0.399\n",
+ " mouse 128 2 1 0 0.0483 0.00483\n",
+ " remote 128 8 0.809 0.5 0.582 0.514\n",
+ " cell phone 128 8 0 0 0.0877 0.046\n",
+ " microwave 128 3 0.595 1 0.753 0.62\n",
+ " oven 128 5 0.518 0.4 0.41 0.302\n",
+ " sink 128 6 0.312 0.167 0.385 0.198\n",
+ " refrigerator 128 5 0.631 0.4 0.654 0.525\n",
+ " book 128 29 0.711 0.172 0.398 0.227\n",
+ " clock 128 9 0.885 0.889 0.92 0.77\n",
+ " vase 128 2 0.489 1 0.828 0.795\n",
+ " scissors 128 1 1 0 0.497 0.149\n",
+ " teddy bear 128 21 0.749 0.569 0.643 0.412\n",
+ " toothbrush 128 5 1 0.569 0.826 0.527\n",
+ "Speed: 1.0ms preprocess, 70.0ms inference, 0.0ms loss, 1.8ms postprocess per image\n",
+ "Results saved to \u001b[1mruns/detect/train4\u001b[0m\n"
+ ]
+ }
+ ],
+ "source": [
+ "op._model.train(data=\"coco128.yaml\", epochs=3)"
+ ]
+ },
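+ {
+ "cell_type": "markdown",
+ "id": "ab4f90c5",
+ "metadata": {},
+ "source": [
+ "After training, the best checkpoint can be fed back into the same pipeline. A minimal sketch, assuming the run directory reported above (`runs/detect/train4`) still exists and that the op's `model` argument accepts a local weights path:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bc17e2d6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from towhee import pipe, ops, DataCollection\n",
+ "\n",
+ "# Reuse the trained weights from the run above (path taken from the training log).\n",
+ "p_trained = (\n",
+ "    pipe.input('path')\n",
+ "    .map('path', 'img', ops.image_decode.cv2_rgb())\n",
+ "    .map('img', ('box', 'class', 'score'), ops.object_detection.yolo(model=\"runs/detect/train4/weights/best.pt\"))\n",
+ "    .map(('img', 'box'), 'object', ops.image_crop(clamp=True))\n",
+ "    .output('img', 'object', 'class')\n",
+ ")\n",
+ "\n",
+ "DataCollection(p_trained(\"https://ultralytics.com/images/bus.jpg\")).show()"
+ ]
+ },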
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f9004cb5",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}