perf(auto-labeling): Enhance expandability (#69)
QIN2DIM authored Oct 25, 2023
1 parent 94646fc · commit 6aa00b2
Showing 6 changed files with 122 additions and 116 deletions.
@@ -1,176 +1,160 @@
 # -*- coding: utf-8 -*-
-# Time : 2023/10/20 17:28
+# Time : 2023/10/24 5:39
 # Author : QIN2DIM
 # GitHub : https://github.com/QIN2DIM
-# Description: zero-shot image classification
+# Description:
 from __future__ import annotations
 
 import logging
 import os
 import shutil
 import sys
-from dataclasses import dataclass
-from dataclasses import field
+from dataclasses import dataclass, field
 from datetime import datetime
 from pathlib import Path
-from typing import List, Tuple
+from typing import Tuple, List
 
 from PIL import Image
-from hcaptcha_challenger import split_prompt_message, label_cleaning
+from hcaptcha_challenger import (
+    DataLake,
+    install,
+    ModelHub,
+    ZeroShotImageClassifier,
+    register_pipline,
+)
 from tqdm import tqdm
 
-project_dir = Path(__file__).parent.parent
-db_dir = project_dir.joinpath("database2309")
+from flow_card import datalake_card
 
 logging.basicConfig(
     level=logging.INFO, stream=sys.stdout, format="%(asctime)s - %(levelname)s - %(message)s"
 )
 
+install(upgrade=True)
+
 
 @dataclass
 class AutoLabeling:
-    positive_labels: List[str] = field(default_factory=list)
-    candidate_labels: List[str] = field(default_factory=list)
-    images_dir: Path = field(default=Path)
-    pending_tasks: List[Path] = field(default_factory=list)
-    checkpoint = "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
-    # checkpoint = "QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336"
-
-    output_dir: Path = None
-
-    def load_zero_shot_model(self):
-        import torch
-        from transformers import pipeline
-
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        task = "zero-shot-image-classification"
-
-        detector = pipeline(task=task, model=self.checkpoint, device=device, batch_size=8)
-
-        return detector
-
-    @classmethod
-    def from_prompt(cls, positive_labels: List[str], candidate_labels: List[str], images_dir: Path):
-        images_dir.mkdir(parents=True, exist_ok=True)
-
-        pending_tasks: List[Path] = []
-        for image_name in os.listdir(images_dir):
-            image_path = images_dir.joinpath(image_name)
-            if image_path.is_file():
-                pending_tasks.append(image_path)
-
-        return cls(
-            positive_labels=positive_labels,
-            candidate_labels=candidate_labels,
-            images_dir=images_dir,
-            pending_tasks=pending_tasks,
-        )
-
-    def valid(self):
-        if not self.pending_tasks:
-            print("No pending tasks")
-            return
-        if len(self.candidate_labels) <= 2:
-            print(f">> Please enter at least three class names - {self.candidate_labels=}")
-            return
-
-        return True
+    """
+    Example:
+    ---
+    1. Roughly observe the distribution of the dataset and design a DataLake for the challenge prompt.
+       - ChallengePrompt: "Please click each image containing an off-road vehicle"
+       - positive_labels --> ["off-road vehicle"]
+       - negative_labels --> ["bicycle", "car"]
+
+    2. You can design them in batches and save them as YAML files,
+       which the classifier can read to automatically build a DataLake.
+
+    3. Note that positive_labels is a list, and you can specify multiple labels for this variable
+       if the label pointed to by the prompt contains ambiguity.
+    """
+
+    input_dir: Path = field(default_factory=Path)
+    pending_tasks: List[Path] = field(default_factory=list)
+    tool: ZeroShotImageClassifier = field(default_factory=ZeroShotImageClassifier)
+
+    output_dir: Path = field(default_factory=Path)
+
+    limit: int = field(default=1)
+    """
+    By default, all pictures in the specified folder are classified and moved;
+    `limit` caps the number of images handled in a single run.
+    """
+
+    @classmethod
+    def from_datalake(cls, dl: DataLake, **kwargs):
+        if not isinstance(dl.joined_dirs, Path):
+            raise TypeError(
+                f"The dataset joined_dirs needs to be passed in for auto-labeling. - {dl.joined_dirs=}"
+            )
+        if not dl.joined_dirs.exists():
+            raise ValueError(f"Specified dataset path does not exist - {dl.joined_dirs=}")
+
+        input_dir = dl.joined_dirs
+        pending_tasks = []
+        for image_name in os.listdir(input_dir):
+            image_path = input_dir.joinpath(image_name)
+            if image_path.is_file():
+                pending_tasks.append(image_path)
+
+        if (limit := kwargs.get("limit")) is None:
+            limit = len(pending_tasks)
+        elif not isinstance(limit, int) or limit < 1:
+            raise ValueError(f"limit should be a positive integer greater than zero. - {limit=}")
+
+        tool = ZeroShotImageClassifier.from_datalake(dl)
+        return cls(tool=tool, input_dir=input_dir, pending_tasks=pending_tasks, limit=limit)
 
     def mkdir(self) -> Tuple[Path, Path]:
         __formats = ("%Y-%m-%d %H:%M:%S.%f", "%Y%m%d%H%M")
         now = datetime.strptime(str(datetime.now()), __formats[0]).strftime(__formats[1])
-        yes_dir = self.images_dir.joinpath(now, "yes")
-        bad_dir = self.images_dir.joinpath(now, "bad")
+        yes_dir = self.input_dir.joinpath(now, "yes")
+        bad_dir = self.input_dir.joinpath(now, "bad")
         yes_dir.mkdir(parents=True, exist_ok=True)
         bad_dir.mkdir(parents=True, exist_ok=True)
 
         self.output_dir = yes_dir.parent
 
         return yes_dir, bad_dir
 
-    def execute(self, limit: int | str = None):
-        if not self.valid():
+    def execute(self, model):
+        if not self.pending_tasks:
+            logging.info("No pending tasks")
             return
 
         # Format datafolder
         yes_dir, bad_dir = self.mkdir()
 
-        # Load zero-shot model
-        detector = self.load_zero_shot_model()
-
+        desc_in = f'"{self.input_dir.parent.name}/{self.input_dir.name}"'
         total = len(self.pending_tasks)
-        desc_in = f'"{self.checkpoint}/{self.images_dir.name}"'
-        if isinstance(limit, str) and limit == "all":
-            limit = total
-        else:
-            limit = limit or total
+
+        logging.info(f"load {self.tool.positive_labels=}")
+        logging.info(f"load {self.tool.candidate_labels=}")
 
         with tqdm(total=total, desc=f"Labeling | {desc_in}") as progress:
-            for image_path in self.pending_tasks[:limit]:
+            for image_path in self.pending_tasks[: self.limit]:
                 # The label at position 0 is the highest scoring target
                 image = Image.open(image_path)
+                results = self.tool(model, image)
 
-                # Binary Image classification
-                predictions = detector(image, candidate_labels=self.candidate_labels)
-
                 # Move positive cases to yes/
                 # Move negative cases to bad/
-                if predictions[0]["label"] in self.positive_labels:
+                # we're only dealing with binary classification tasks here
+                if results[0]["label"] in self.tool.positive_labels:
                     output_path = yes_dir.joinpath(image_path.name)
                 else:
                     output_path = bad_dir.joinpath(image_path.name)
 
                 shutil.move(image_path, output_path)
 
                 progress.update(1)
 
 
-@dataclass
-class DataGroup:
-    positive_labels: List[str] | str
-    joined_dirs: List[str]
-    negative_labels: List[str]
-
-    def __post_init__(self):
-        if isinstance(self.positive_labels, str):
-            self.positive_labels = [self.positive_labels]
-
-    @property
-    def input_dir(self):
-        return db_dir.joinpath(*self.joined_dirs).absolute()
-
-    def auto_labeling(self, **kwargs):
-        pls = []
-        for pl in self.positive_labels:
-            pl = pl.replace("_", " ")
-            pl = split_prompt_message(label_cleaning(pl), "en")
-            pls.append(pl)
-
-        candidate_labels = pls.copy()
-        if isinstance(self.negative_labels, list) and len(self.negative_labels) != 0:
-            candidate_labels.extend(self.negative_labels)
-
-        al = AutoLabeling.from_prompt(pls, candidate_labels, self.input_dir)
-        al.execute(limit=kwargs.get("limit"))
-
-        return al
-
-
-def edit_in_the_common_cases():
-    # prompt to negative labels
-    # input_dir = /[Project_dir]/database2309/*[joined_dirs]
-
-    dg = DataGroup(
-        positive_labels=["helicopter", "excavator"],
-        joined_dirs=["motorized_machine"],
-        negative_labels=["laptop", "chess", "plant", "natural landscape", "mountain"],
-    )
-
-    dg = DataGroup(
-        positive_labels=["off road vehicle"],
-        joined_dirs=["off_road_vehicle"],
-        negative_labels=["bicycle", "car"],
-    )
-
-    nox = dg.auto_labeling(limit="all")
-    if "win32" in sys.platform and nox.output_dir:
-        os.startfile(nox.output_dir)
+def run():
+    modelhub = ModelHub.from_github_repo()
+    modelhub.parse_objects()
+
+    model = register_pipline(modelhub)
+
+    images_dir = Path(__file__).parent.parent.joinpath("database2309")
+
+    for card in datalake_card:
+        # Filter out the task cards we care about
+        if "furniture" not in card["joined_dirs"]:
+            continue
+        # Generating a dataclass from serialized data
+        dl = DataLake(
+            positive_labels=card["positive_labels"],
+            negative_labels=card["negative_labels"],
+            joined_dirs=images_dir.joinpath(*card["joined_dirs"]),
+        )
+        # Starts an automatic labeling task
+        al = AutoLabeling.from_datalake(dl)
+        al.execute(model)
+        # Automatically open output directory
+        if "win32" in sys.platform and al.output_dir.is_dir():
+            os.startfile(al.output_dir)
 
 
 if __name__ == "__main__":
-    edit_in_the_common_cases()
+    run()
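Point 2 of the new AutoLabeling docstring mentions designing label sets in batches and saving them as YAML files that the classifier can read. This commit ships the cards as a Python list, so the following is only a sketch of that YAML route: the flow_card.yaml filename and the load_datalakes helper are hypothetical, while the card schema and the DataLake constructor arguments are taken from the diff above.

# Sketch only: batch-designed cards kept in a hypothetical flow_card.yaml
# (same schema as the card list added in this commit) and turned into
# DataLake objects for AutoLabeling.from_datalake().
from pathlib import Path
from typing import List

import yaml  # PyYAML, assumed to be available

from hcaptcha_challenger import DataLake

images_dir = Path(__file__).parent.parent.joinpath("database2309")


def load_datalakes(card_path: Path) -> List[DataLake]:
    # Each YAML entry carries positive_labels / negative_labels / joined_dirs,
    # mirroring the in-code cards consumed by run().
    cards = yaml.safe_load(card_path.read_text(encoding="utf8"))
    return [
        DataLake(
            positive_labels=card["positive_labels"],
            negative_labels=card["negative_labels"],
            joined_dirs=images_dir.joinpath(*card["joined_dirs"]),
        )
        for card in cards
    ]

Each DataLake produced this way can then be handed to AutoLabeling.from_datalake(dl, limit=...) exactly as run() does with the in-code cards.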
@@ -0,0 +1,19 @@
+# -*- coding: utf-8 -*-
+# Time : 2023/10/26 2:58
+# Author : QIN2DIM
+# GitHub : https://github.com/QIN2DIM
+# Description:
+# Run `assets_manager.py` to get test data from GitHub issues
+
+flow_card = [
+    {
+        "positive_labels": ["off-road vehicle"],
+        "negative_labels": ["car", "bicycle"],
+        "joined_dirs": ["off_road_vehicle"],
+    },
+    {
+        "positive_labels": ["furniture", "chair"],
+        "negative_labels": ["guitar", "keyboard", "game tool", "headphones"],
+        "joined_dirs": ["furniture"],
+    },
+]
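Extending coverage to another challenge then comes down to appending a card to this list and pointing the filter in run() at its directory. The sketch below reuses the "helicopter"/"motorized_machine" labels from the removed DataGroup demo as placeholders; the exact label choices are illustrative, not part of this commit.

# Hypothetical extra card; labels borrowed from the old DataGroup example.
flow_card.append(
    {
        "positive_labels": ["helicopter", "excavator"],
        "negative_labels": ["laptop", "chess", "plant", "natural landscape", "mountain"],
        "joined_dirs": ["motorized_machine"],
    }
)

# run() currently skips every card whose joined_dirs lacks "furniture";
# widen or change that filter so the new card is picked up, e.g.:
# if "motorized_machine" not in card["joined_dirs"]:
#     continue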