From be7b0201de8011419bc84f3fdce468d539337828 Mon Sep 17 00:00:00 2001
From: tianwei
Date: Thu, 30 Nov 2023 16:21:25 +0800
Subject: [PATCH] add example for object detection

---
 example/object-detection/README.md            |  49 +++++++
 example/object-detection/datasets/.gitignore  |   1 +
 example/object-detection/datasets/coco128.py  |  76 ++++++++++
 .../object-detection/datasets/coco_val2017.py |  76 ++++++++++
 example/object-detection/datasets/utils.py    | 135 ++++++++++++++++++
 .../object-detection/models/yolo/.gitignore   |   3 +
 .../object-detection/models/yolo/.swignore    |   2 +
 .../object-detection/models/yolo/README.md    |   0
 example/object-detection/models/yolo/build.py |  56 ++++++++
 .../object-detection/models/yolo/consts.py    |  82 +++++++++++
 .../models/yolo/evaluation.py                 | 112 +++++++++++++++
 .../object-detection/runtime/requirements.txt |   5 +
 example/object-detection/runtime/runtime.yaml |   7 +
 13 files changed, 604 insertions(+)
 create mode 100644 example/object-detection/README.md
 create mode 100644 example/object-detection/datasets/.gitignore
 create mode 100644 example/object-detection/datasets/coco128.py
 create mode 100644 example/object-detection/datasets/coco_val2017.py
 create mode 100644 example/object-detection/datasets/utils.py
 create mode 100644 example/object-detection/models/yolo/.gitignore
 create mode 100644 example/object-detection/models/yolo/.swignore
 create mode 100644 example/object-detection/models/yolo/README.md
 create mode 100644 example/object-detection/models/yolo/build.py
 create mode 100644 example/object-detection/models/yolo/consts.py
 create mode 100644 example/object-detection/models/yolo/evaluation.py
 create mode 100644 example/object-detection/runtime/requirements.txt
 create mode 100644 example/object-detection/runtime/runtime.yaml

diff --git a/example/object-detection/README.md b/example/object-detection/README.md
new file mode 100644
index 0000000000..363ff0a181
--- /dev/null
+++ b/example/object-detection/README.md
@@ -0,0 +1,49 @@
+Object Detection
+======
+
+Object detection is a computer vision technique for locating instances of objects in images or videos. Object detection algorithms typically leverage machine learning or deep learning to produce meaningful results.
+
+In these examples, we will use Starwhale to evaluate a set of object detection models on COCO datasets.
+
+Thanks to [ultralytics](https://github.com/ultralytics/ultralytics), which makes Starwhale model evaluation on YOLO easy.
+
+Links
+------
+
+- Github Example Code:
+- Starwhale Cloud Demo:
+
+What we learn
+------
+
+- Build a Starwhale dataset with the Starwhale Python SDK and browse it in the Starwhale dataset web viewer.
+
+Models
+------
+
+- [YOLO](https://docs.ultralytics.com/): We will compare evaluations of the YOLOv8-{n,s,m,l,x} and YOLOv6-{n,s,m,l,l6} models.
+
+Datasets
+------
+
+- [COCO128](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/coco128.yaml)
+
+  - Introduction: Ultralytics COCO128 is a small but versatile object detection dataset composed of the first 128 images of the COCO train2017 set. This dataset is ideal for testing and debugging object detection models.
+  - Size: 128 validation images.
+  - Dataset build command:
+
+    ```bash
+    swcli runtime activate object-detection
+    python3 datasets/coco128.py
+    ```
+
+- [COCO_val2017](https://cocodataset.org/#download)
+
+  - Introduction: The COCO (Common Objects in Context) dataset is a large-scale object detection, segmentation, and captioning dataset. It is designed to encourage research on a wide variety of object categories and is commonly used for benchmarking computer vision models. The dataset comprises 80 object categories.
+  - Size: 5,000 validation images.
+  - Dataset build command:
+
+    ```bash
+    swcli runtime activate object-detection
+    python3 datasets/coco_val2017.py
+    ```
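After either build command finishes, the dataset can be inspected from Python as well as in the web viewer. A minimal sketch, assuming the `coco128` dataset built above exists in the local Starwhale instance (row iteration and the `features` accessor follow the Starwhale SDK conventions used elsewhere in this example):

```python
from starwhale import dataset

# a minimal sketch: open the locally built dataset and peek at one row
ds = dataset("coco128")
for row in ds:
    # each row was written as {"image": Image(...), "annotations": [...]}
    print(row.index, [a["class_name"] for a in row.features["annotations"]])
    break
```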
diff --git a/example/object-detection/datasets/.gitignore b/example/object-detection/datasets/.gitignore
new file mode 100644
index 0000000000..07f43b870e
--- /dev/null
+++ b/example/object-detection/datasets/.gitignore
@@ -0,0 +1 @@
+data/*
\ No newline at end of file
diff --git a/example/object-detection/datasets/coco128.py b/example/object-detection/datasets/coco128.py
new file mode 100644
index 0000000000..57c3fb0b0f
--- /dev/null
+++ b/example/object-detection/datasets/coco128.py
@@ -0,0 +1,76 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+from utils import download, extract_zip, get_name_by_coco_category_id
+
+from starwhale import Image, dataset, BoundingBox, init_logger
+from starwhale.utils import console
+
+init_logger(3)
+
+ROOT = Path(__file__).parent
+DATA_DIR = ROOT / "data" / "coco128"
+
+# Copied from https://www.kaggle.com/datasets/ultralytics/coco128.
+
+
+def build() -> None:
+    _zip_path = DATA_DIR / "coco128.zip"
+    download("https://ultralytics.com/assets/coco128.zip", _zip_path)
+    extract_zip(
+        _zip_path, DATA_DIR, DATA_DIR / "coco128/images/train2017/000000000650.jpg"
+    )
+
+    with dataset("coco128") as ds:
+        for img_path in (DATA_DIR / "coco128/images/train2017").glob("*.jpg"):
+            name = img_path.name.split(".jpg")[0]
+
+            # YOLO Darknet format: https://docs.plainsight.ai/labels/exporting-labels/yolo
+            # Format: <object-class> <x_center> <y_center> <width> <height>
+            # Meaning:
+            #   <object-class> - zero-based index representing the class in obj.names, from 0 to (classes-1).
+            #   <x_center> <y_center> <width> <height> - float values relative to the width and height of the image, in the range (0.0, 1.0].
+            #   <x_center> = <absolute_x> / <image_width>
+            #   <y_center> = <absolute_y> / <image_height>
+
+            annotations = []
+            image = Image(img_path)
+            i_width, i_height = image.to_pil().size
+
+            label_path = DATA_DIR / "coco128/labels/train2017" / f"{name}.txt"
+            if not label_path.exists():
+                continue
+
+            for line in label_path.read_text().splitlines():
+                class_id, x, y, w, h = line.split()
+                class_id, x, y, w, h = (
+                    int(class_id),
+                    float(x),
+                    float(y),
+                    float(w),
+                    float(h),
+                )
+                annotations.append(
+                    {
+                        "class_id": class_id,
+                        "class_name": get_name_by_coco_category_id(class_id),
+                        "darknet_bbox": [x, y, w, h],
+                        "bbox": BoundingBox(
+                            x=(x - w / 2) * i_width,
+                            y=(y - h / 2) * i_height,
+                            width=w * i_width,
+                            height=h * i_height,
+                        ),
+                    }
+                )
+
+            ds[name] = {"image": image, "annotations": annotations}
+
+        console.print("commit dataset...")
+        ds.commit()
+
+    console.print(f"{ds} has been built successfully!")
+
+
+if __name__ == "__main__":
+    build()
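The label conversion above turns a Darknet `(x_center, y_center, width, height)` box, normalized to the image size, into the absolute top-left `BoundingBox` that Starwhale renders. A small self-contained check with hypothetical numbers:

```python
# worked example of the darknet -> absolute conversion used in coco128.py
def darknet_to_xywh(
    x: float, y: float, w: float, h: float, img_w: int, img_h: int
) -> tuple[float, float, float, float]:
    # shift from the box center to the top-left corner, then scale to pixels
    return ((x - w / 2) * img_w, (y - h / 2) * img_h, w * img_w, h * img_h)

# a 0.2 x 0.4 box centered in a 640x480 image
assert darknet_to_xywh(0.5, 0.5, 0.2, 0.4, 640, 480) == (256.0, 144.0, 128.0, 192.0)
```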
diff --git a/example/object-detection/datasets/coco_val2017.py b/example/object-detection/datasets/coco_val2017.py
new file mode 100644
index 0000000000..570340ceb9
--- /dev/null
+++ b/example/object-detection/datasets/coco_val2017.py
@@ -0,0 +1,76 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from collections import defaultdict
+
+from tqdm import tqdm
+from utils import download, extract_zip, get_name_by_coco_category_id
+from ultralytics.data.converter import coco91_to_coco80_class
+
+from starwhale import Image, dataset, init_logger
+from starwhale.utils import console
+from starwhale.base.data_type import BoundingBox
+
+init_logger(3)
+
+ROOT = Path(__file__).parent
+DATA_DIR = ROOT / "data" / "coco2017"
+
+# The coco2017 val set is from https://cocodataset.org/#download.
+
+
+def build() -> None:
+    _zip_path = DATA_DIR / "val2017.zip"
+    download(
+        "https://starwhale-examples.oss-cn-beijing.aliyuncs.com/dataset/coco2017/val2017.zip",
+        _zip_path,
+    )
+    extract_zip(_zip_path, DATA_DIR, DATA_DIR / "val2017/000000000139.jpg")
+
+    _zip_path = DATA_DIR / "annotations_trainval2017.zip"
+    download(
+        "https://starwhale-examples.oss-cn-beijing.aliyuncs.com/dataset/coco2017/annotations_trainval2017.zip",
+        _zip_path,
+    )
+    json_path = DATA_DIR / "annotations/instances_val2017.json"
+    extract_zip(_zip_path, DATA_DIR, json_path)
+
+    coco_classes = coco91_to_coco80_class()
+
+    with json_path.open() as f:
+        content = json.load(f)
+        annotations = defaultdict(list)
+        for ann in content["annotations"]:
+            class_id = coco_classes[ann["category_id"] - 1]
+            annotations[ann["image_id"]].append(
+                {
+                    "bbox": BoundingBox(*ann["bbox"]),
+                    "class_id": class_id,
+                    "class_name": get_name_by_coco_category_id(class_id),
+                }
+            )
+
+    with dataset("coco_val2017") as ds:
+        for image in tqdm(content["images"]):
+            name = image["file_name"].split(".jpg")[0]
+            for ann in annotations[image["id"]]:
+                bbox = ann["bbox"]
+                ann["darknet_bbox"] = [
+                    (bbox.x + bbox.width / 2) / image["width"],
+                    (bbox.y + bbox.height / 2) / image["height"],
+                    bbox.width / image["width"],
+                    bbox.height / image["height"],
+                ]
+            ds[name] = {
+                "image": Image(DATA_DIR / "val2017" / image["file_name"]),
+                "annotations": annotations[image["id"]],
+            }
+        console.print("commit dataset...")
+        ds.commit()
+
+    console.print(f"{ds} has been built successfully!")
+
+
+if __name__ == "__main__":
+    build()
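`coco91_to_coco80_class()` bridges the two COCO id schemes: the annotation JSON uses the sparse 1-91 `category_id` values, while YOLO models predict the dense 0-79 indices. A quick sketch of what the mapping does (well-known COCO ids; the `None` entries are ids unused by the 80-class scheme):

```python
from ultralytics.data.converter import coco91_to_coco80_class

coco80 = coco91_to_coco80_class()  # 91 entries; index it with category_id - 1
print(coco80[1 - 1])   # 0    -> "person"
print(coco80[90 - 1])  # 79   -> "toothbrush"
print(coco80[12 - 1])  # None -> category 12 ("street sign") has no 80-class slot
```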
diff --git a/example/object-detection/datasets/utils.py b/example/object-detection/datasets/utils.py
new file mode 100644
index 0000000000..eaa64a4c56
--- /dev/null
+++ b/example/object-detection/datasets/utils.py
@@ -0,0 +1,135 @@
+from __future__ import annotations
+
+import zipfile
+from pathlib import Path
+
+import requests
+from tqdm import tqdm
+
+from starwhale.utils import console
+
+_COCO_CLASSES_MAP = {
+    0: "person",
+    1: "bicycle",
+    2: "car",
+    3: "motorcycle",
+    4: "airplane",
+    5: "bus",
+    6: "train",
+    7: "truck",
+    8: "boat",
+    9: "traffic light",
+    10: "fire hydrant",
+    11: "stop sign",
+    12: "parking meter",
+    13: "bench",
+    14: "bird",
+    15: "cat",
+    16: "dog",
+    17: "horse",
+    18: "sheep",
+    19: "cow",
+    20: "elephant",
+    21: "bear",
+    22: "zebra",
+    23: "giraffe",
+    24: "backpack",
+    25: "umbrella",
+    26: "handbag",
+    27: "tie",
+    28: "suitcase",
+    29: "frisbee",
+    30: "skis",
+    31: "snowboard",
+    32: "sports ball",
+    33: "kite",
+    34: "baseball bat",
+    35: "baseball glove",
+    36: "skateboard",
+    37: "surfboard",
+    38: "tennis racket",
+    39: "bottle",
+    40: "wine glass",
+    41: "cup",
+    42: "fork",
+    43: "knife",
+    44: "spoon",
+    45: "bowl",
+    46: "banana",
+    47: "apple",
+    48: "sandwich",
+    49: "orange",
+    50: "broccoli",
+    51: "carrot",
+    52: "hot dog",
+    53: "pizza",
+    54: "donut",
+    55: "cake",
+    56: "chair",
+    57: "couch",
+    58: "potted plant",
+    59: "bed",
+    60: "dining table",
+    61: "toilet",
+    62: "tv",
+    63: "laptop",
+    64: "mouse",
+    65: "remote",
+    66: "keyboard",
+    67: "cell phone",
+    68: "microwave",
+    69: "oven",
+    70: "toaster",
+    71: "sink",
+    72: "refrigerator",
+    73: "book",
+    74: "clock",
+    75: "vase",
+    76: "scissors",
+    77: "teddy bear",
+    78: "hair drier",
+    79: "toothbrush",
+}
+
+
+def get_name_by_coco_category_id(category_id: int | None) -> str:
+    return (
+        _COCO_CLASSES_MAP[category_id] if category_id is not None else "uncategorized"
+    )
+
+
+def extract_zip(from_path: Path, to_path: Path, chk_path: Path) -> None:
+    if chk_path.exists():
+        console.log(f"skip extract {from_path}, dir {chk_path} already exists")
+        return
+
+    with zipfile.ZipFile(from_path, "r", zipfile.ZIP_STORED) as z:
+        for file in tqdm(
+            iterable=z.namelist(),
+            total=len(z.namelist()),
+            desc=f"extract {from_path.name}",
+        ):
+            z.extract(member=file, path=to_path)
+
+
+def download(url: str, to_path: Path) -> None:
+    if to_path.exists():
+        console.log(f"skip download {url}, file {to_path} already exists")
+        return
+
+    to_path.parent.mkdir(parents=True, exist_ok=True)
+
+    with requests.get(url, timeout=60, stream=True) as r:
+        r.raise_for_status()
+        size = int(r.headers.get("content-length", 0))
+        # the bar counts bytes, so update it by chunk size instead of also
+        # iterating the bar itself (which would count each chunk twice)
+        with tqdm(
+            total=size,
+            unit="B",
+            unit_scale=True,
+            desc=f"download {url}",
+        ) as pbar:
+            with open(to_path, "wb") as f:
+                for chunk in r.iter_content(chunk_size=1024):
+                    f.write(chunk)
+                    pbar.update(len(chunk))
diff --git a/example/object-detection/models/yolo/.gitignore b/example/object-detection/models/yolo/.gitignore
new file mode 100644
index 0000000000..dfed2effa0
--- /dev/null
+++ b/example/object-detection/models/yolo/.gitignore
@@ -0,0 +1,3 @@
+checkpoints/*
+runs/*
+flagged/*
diff --git a/example/object-detection/models/yolo/.swignore b/example/object-detection/models/yolo/.swignore
new file mode 100644
index 0000000000..12ce64a37e
--- /dev/null
+++ b/example/object-detection/models/yolo/.swignore
@@ -0,0 +1,2 @@
+runs/*
+checkpoints/cache/*
diff --git a/example/object-detection/models/yolo/README.md b/example/object-detection/models/yolo/README.md
new file mode 100644
index 0000000000..e69de29bb2
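Both helpers in `utils.py` are idempotent: `download` skips when the target file already exists, and `extract_zip` skips when a sentinel path known to live inside the archive exists, so the dataset build scripts are safe to re-run after an interrupted build. A usage sketch (paths hypothetical):

```python
from pathlib import Path

from utils import download, extract_zip

zip_path = Path("data/demo/coco128.zip")
download("https://ultralytics.com/assets/coco128.zip", zip_path)  # fetches once
download("https://ultralytics.com/assets/coco128.zip", zip_path)  # logs "skip download ..."
# the third argument is a sentinel checked before extracting again
extract_zip(zip_path, zip_path.parent, zip_path.parent / "coco128/images/train2017")
```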
diff --git a/example/object-detection/models/yolo/build.py b/example/object-detection/models/yolo/build.py
new file mode 100644
index 0000000000..dd53ba7b69
--- /dev/null
+++ b/example/object-detection/models/yolo/build.py
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+
+from starwhale import model as starwhale_model
+
+ROOT = Path(__file__).parent
+CHECKPOINTS_DIR = ROOT / "checkpoints"
+
+SUPPORT_MODELS = (
+    "yolov8n",
+    "yolov8s",
+    "yolov8m",
+    "yolov8l",
+    "yolov8x",
+    "yolov6-n",
+    "yolov6-s",
+    "yolov6-m",
+    "yolov6-l",
+    "yolov6-l6",
+)
+
+
+def build(model: str) -> None:
+    print(f"start to build {model} yolo model...")
+    fpath = CHECKPOINTS_DIR / "cache" / f"{model}.pt"
+    if not fpath.exists():
+        from torch.hub import download_url_to_file
+
+        fpath.parent.mkdir(parents=True, exist_ok=True)
+        download_url_to_file(
+            url=f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{model}.pt",
+            dst=str(fpath),
+        )
+
+    (CHECKPOINTS_DIR / ".model").write_text(model)
+    for pt in CHECKPOINTS_DIR.glob("*.pt"):
+        pt.unlink()
+    # hard-link the cached checkpoint into place; hardlink_to replaces the
+    # deprecated Path.link_to and is available on Python 3.10+
+    (CHECKPOINTS_DIR / f"{model}.pt").hardlink_to(fpath)
+
+    starwhale_model.build(name=model, modules=["evaluation"])
+
+
+if __name__ == "__main__":
+    if len(sys.argv[1:]) == 0:
+        print(f"please specify a model name, supported: {SUPPORT_MODELS}")
+        sys.exit(1)
+    elif sys.argv[1] == "all":
+        print("build all supported yolo models")
+        models = SUPPORT_MODELS
+    else:
+        models = [sys.argv[1]]
+
+    for model in models:
+        build(model)
diff --git a/example/object-detection/models/yolo/consts.py b/example/object-detection/models/yolo/consts.py
new file mode 100644
index 0000000000..eae5df4265
--- /dev/null
+++ b/example/object-detection/models/yolo/consts.py
@@ -0,0 +1,82 @@
+COCO_CLASSES_MAP = {
+    0: "person",
+    1: "bicycle",
+    2: "car",
+    3: "motorcycle",
+    4: "airplane",
+    5: "bus",
+    6: "train",
+    7: "truck",
+    8: "boat",
+    9: "traffic light",
+    10: "fire hydrant",
+    11: "stop sign",
+    12: "parking meter",
+    13: "bench",
+    14: "bird",
+    15: "cat",
+    16: "dog",
+    17: "horse",
+    18: "sheep",
+    19: "cow",
+    20: "elephant",
+    21: "bear",
+    22: "zebra",
+    23: "giraffe",
+    24: "backpack",
+    25: "umbrella",
+    26: "handbag",
+    27: "tie",
+    28: "suitcase",
+    29: "frisbee",
+    30: "skis",
+    31: "snowboard",
+    32: "sports ball",
+    33: "kite",
+    34: "baseball bat",
+    35: "baseball glove",
+    36: "skateboard",
+    37: "surfboard",
+    38: "tennis racket",
+    39: "bottle",
+    40: "wine glass",
+    41: "cup",
+    42: "fork",
+    43: "knife",
+    44: "spoon",
+    45: "bowl",
+    46: "banana",
+    47: "apple",
+    48: "sandwich",
+    49: "orange",
+    50: "broccoli",
+    51: "carrot",
+    52: "hot dog",
+    53: "pizza",
+    54: "donut",
+    55: "cake",
+    56: "chair",
+    57: "couch",
+    58: "potted plant",
+    59: "bed",
+    60: "dining table",
+    61: "toilet",
+    62: "tv",
+    63: "laptop",
+    64: "mouse",
+    65: "remote",
+    66: "keyboard",
+    67: "cell phone",
+    68: "microwave",
+    69: "oven",
+    70: "toaster",
+    71: "sink",
+    72: "refrigerator",
+    73: "book",
+    74: "clock",
+    75: "vase",
+    76: "scissors",
+    77: "teddy bear",
+    78: "hair drier",
+    79: "toothbrush",
+}
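`build.py` leaves a small, predictable layout under `checkpoints/` that `evaluation.py` relies on: a `.model` marker naming the active model, a persistent download cache (excluded from the model package via `.swignore`), and a hard link to the active checkpoint. A sketch of inspecting that layout, assuming `python3 build.py yolov8n` has been run in `models/yolo/`:

```python
from pathlib import Path

ckpt = Path("checkpoints")
print((ckpt / ".model").read_text())             # "yolov8n" -- read back by _load_model()
print((ckpt / "yolov8n.pt").stat().st_nlink)     # 2: hard-linked to cache/yolov8n.pt
print((ckpt / "cache" / "yolov8n.pt").exists())  # True: the cache survives model switches
```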
diff --git a/example/object-detection/models/yolo/evaluation.py b/example/object-detection/models/yolo/evaluation.py
new file mode 100644
index 0000000000..ec328aa6c7
--- /dev/null
+++ b/example/object-detection/models/yolo/evaluation.py
@@ -0,0 +1,112 @@
+from __future__ import annotations
+
+import os
+import pickle
+import typing as t
+from pathlib import Path
+
+import torch
+import gradio
+from consts import COCO_CLASSES_MAP
+from ultralytics import YOLO
+from ultralytics.utils.ops import xywh2xyxy
+from ultralytics.utils.metrics import DetMetrics
+from ultralytics.models.yolo.detect.val import DetectionValidator
+
+from starwhale import Image, evaluation, Evaluation
+from starwhale.api.service import api
+
+_g_model: YOLO | None = None
+
+CHECKPOINT_DIR = Path(__file__).parent / "checkpoints"
+
+
+def _load_model(is_evaluation: bool = False) -> YOLO:
+    global _g_model
+    if _g_model is None:
+        # TODO: load model by build tag
+        if (CHECKPOINT_DIR / ".model").exists():
+            model_name = (CHECKPOINT_DIR / ".model").read_text().strip()
+        else:
+            model_name = "yolov8n"
+
+        _g_model = YOLO(CHECKPOINT_DIR / f"{model_name}.pt")
+
+    return _g_model
+
+
+@torch.no_grad()
+@evaluation.predict(replicas=1, log_mode="plain", log_dataset_features=["image"])
+def predict_image(data: t.Dict, external: t.Dict) -> t.Dict:
+    img = data["image"].to_pil()
+    img.filename = f"{external['index']}.jpg"  # workaround so ultralytics can save the image
+
+    # TODO: support batch
+    # TODO: support arguments
+    model = _load_model(is_evaluation=True)
+    result = model.predict(
+        img,
+        save=True,
+        conf=float(os.environ.get("OBJECT_DETECTION_CONF", "0.1")),
+        iou=float(os.environ.get("OBJECT_DETECTION_IOU", "0.5")),
+        max_det=int(os.environ.get("OBJECT_DETECTION_MAX_DET", "300")),
+    )[0]
+    device = model.device
+
+    label_classes = torch.as_tensor(
+        [[int(ann["class_id"])] for ann in data["annotations"]], device=device
+    )
+    # ultralytics Results.orig_shape is (height, width)
+    height, width = result.orig_shape
+    # darknet bboxes are normalized (x_center, y_center, w, h); scale them to
+    # absolute pixels and convert to xyxy, which _process_batch expects for labels
+    label_xyxy_bboxes = xywh2xyxy(
+        torch.as_tensor(
+            [ann["darknet_bbox"] for ann in data["annotations"]], device=device
+        )
+        * torch.tensor([width, height, width, height], device=device)
+    )
+
+    correct_bboxes = DetectionValidator()._process_batch(
+        detections=result.boxes.data,
+        labels=torch.cat((label_classes, label_xyxy_bboxes), dim=1),
+    )
+    # Copied from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/detect/val.py#L89
+    metric = (
+        correct_bboxes,  # correct bboxes
+        result.boxes.conf,  # predicted confidences
+        result.boxes.cls,  # predicted classes
+        label_classes.squeeze(-1),  # label classes
+    )
+
+    return {
+        "speed": result.speed,
+        "predicted_image": Image(os.path.join(result.save_dir, img.filename)),
+        "ultralytics_metric": pickle.dumps(metric),
+    }
+
+
+@evaluation.evaluate(needs=[predict_image])
+def summary_detection(predict_result_iter: t.Iterator) -> None:
+    stats = []
+    for predict in predict_result_iter:
+        stats.append(pickle.loads(predict["output/ultralytics_metric"]))
+    stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]
+
+    metrics = DetMetrics(names=COCO_CLASSES_MAP)
+    metrics.process(*stats)
+
+    e = Evaluation.from_context()
+    # the starwhale datastore does not support dict keys with parentheses
+    keys = [k.split("(B)")[0] for k in metrics.keys]
+    e.log_summary(dict(zip(keys, metrics.mean_results())))
+
+    for i, c in enumerate(metrics.ap_class_index):
+        e.log(
+            category="per_class",
+            id=COCO_CLASSES_MAP[c],
+            metrics=dict(zip(keys, metrics.class_result(i))),
+        )
+
+
+@torch.no_grad()
+@api(
+    gradio.Image(type="filepath", label="Input Image"),
+    gradio.Image(type="filepath", label="Detected Image"),
+)
+def web_detect_image(file: str) -> Path | str:
+    result = _load_model().predict(file, save=True)[0]
+    return Path(result.save_dir) / Path(result.path).name
diff --git a/example/object-detection/runtime/requirements.txt b/example/object-detection/runtime/requirements.txt
new file mode 100644
index 0000000000..6d3bcdeef9
--- /dev/null
+++ b/example/object-detection/runtime/requirements.txt
@@ -0,0 +1,5 @@
+tqdm
+requests
+torch==2.0.1
+gradio
+ultralytics==8.0.220
\ No newline at end of file
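The prediction thresholds in `predict_image` are read from environment variables rather than hard-coded, so a run can be tuned without rebuilding the model package. The variable names below are exactly the ones the code reads; the values are illustrative:

```python
import os

os.environ["OBJECT_DETECTION_CONF"] = "0.25"    # minimum confidence to keep a box
os.environ["OBJECT_DETECTION_IOU"] = "0.45"     # NMS IoU threshold
os.environ["OBJECT_DETECTION_MAX_DET"] = "300"  # cap on detections per image
```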
diff --git a/example/object-detection/runtime/runtime.yaml b/example/object-detection/runtime/runtime.yaml
new file mode 100644
index 0000000000..6e15c8acac
--- /dev/null
+++ b/example/object-detection/runtime/runtime.yaml
@@ -0,0 +1,7 @@
+name: object-detection
+mode: venv
+environment:
+  cuda: "11.7"
+  python: "3.10"
+dependencies:
+  - requirements.txt