Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

example: refactor cifar10 example #1177

Merged
merged 1 commit into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions client/starwhale/core/dataset/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def astype(self) -> t.Dict[str, t.Any]:
"mime_type": self.mime_type,
"shape": self.shape,
"encoding": self.encoding,
"display_name": self.display_name,
}

def asdict(self, ignore_keys: t.Optional[t.List[str]] = None) -> t.Dict[str, t.Any]:
Expand Down
2 changes: 2 additions & 0 deletions example/cifar10/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
data/
models/
5 changes: 5 additions & 0 deletions example/cifar10/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
.POHNY: train
train:
mkdir -p models data
python3 cifar/train.py

File renamed without changes.
54 changes: 54 additions & 0 deletions example/cifar10/cifar/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import io
import pickle
import typing as t
from pathlib import Path

from PIL import Image as PILImage

from starwhale.api.dataset import Image, MIMEType, BuildExecutor

ROOT_DIR = Path(__file__).parent.parent / "data" / "cifar-10-batches-py"
TRAIN_DATASET_PATHS = [ROOT_DIR / f"data_batch_{i}" for i in range(1, 6)]
TEST_DATASET_PATHS = [ROOT_DIR / "test_batch"]


def parse_meta() -> t.Dict[str, t.Any]:
with (ROOT_DIR / "batches.meta").open("rb") as f:
return pickle.load(f)


dataset_meta = parse_meta()


def _iter_item(paths: t.List[Path]) -> t.Generator[t.Tuple[t.Any, t.Dict], None, None]:
for path in paths:
with path.open("rb") as f:
content = pickle.load(f, encoding="bytes")
for data, label, filename in zip(
content[b"data"], content[b"labels"], content[b"filenames"]
):
annotations = {
"label": label,
"label_display_name": dataset_meta["label_names"][label],
}

image_array = data.reshape(3, 32, 32).transpose(1, 2, 0)
image_bytes = io.BytesIO()
PILImage.fromarray(image_array).save(image_bytes, format="PNG")

yield Image(
fp=image_bytes.getvalue(),
display_name=filename.decode(),
shape=image_array.shape,
mime_type=MIMEType.PNG,
), annotations


class CIFAR10TrainBuildExecutor(BuildExecutor):
def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]:
return _iter_item(TRAIN_DATASET_PATHS)


class CIFAR10TestBuildExecutor(BuildExecutor):
def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]:
return _iter_item(TEST_DATASET_PATHS)
File renamed without changes.
65 changes: 65 additions & 0 deletions example/cifar10/cifar/ppl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import io
from pathlib import Path

import numpy as np
import torch
from PIL import Image as PILImage
from torchvision import transforms

from starwhale.api.job import Context
from starwhale.api.model import PipelineHandler
from starwhale.api.metric import multi_classification
from starwhale.api.dataset import Image

from .model import Net

ROOTDIR = Path(__file__).parent.parent


class CIFAR10Inference(PipelineHandler):
def __init__(self, context: Context) -> None:
super().__init__(context=context)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = self._load_model(self.device)

def ppl(self, img: Image, **kw):
data_tensor = self._pre(img)
output = self.model(data_tensor)
return self._post(output)

@multi_classification(
confusion_matrix_normalize="all",
show_hamming_loss=True,
show_cohen_kappa_score=True,
show_roc_auc=True,
all_labels=[i for i in range(0, 10)],
)
def cmp(self, ppl_result):
result, label, pr = [], [], []
for _data in ppl_result:
label.append(_data["annotations"]["label"])
result.extend(_data["result"][0])
pr.extend(_data["result"][1])
return label, result, pr

def _pre(self, input: Image) -> torch.Tensor:
_image = PILImage.open(io.BytesIO(input.to_bytes()))
_image = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)(_image)
return torch.stack([_image]).to(self.device)

def _post(self, input):
pred_value = input.argmax(1).flatten().tolist()
probability_matrix = np.exp(input.tolist()).tolist()
return pred_value, probability_matrix

def _load_model(self, device):
model = Net().to(device)
model.load_state_dict(torch.load(str(ROOTDIR / "models" / "cifar_net.pth")))
model.eval()
print("load cifar_net model, start to inference...")
return model
34 changes: 18 additions & 16 deletions example/cifar10/code/train.py → example/cifar10/cifar/train.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,37 @@
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from model import Net
from torch.optim.lr_scheduler import StepLR
import os

# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
_ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
_MODEL_PATH = os.path.join(_ROOT_DIR, "../models/cifar_net.pth")
_DATA_DIR = os.path.join(_ROOT_DIR, "data")


def main():
def train():
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
batch_size = 4
trainset = torchvision.datasets.CIFAR10(root='../data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
shuffle=True, num_workers=2)
train_set = torchvision.datasets.CIFAR10(
root=_DATA_DIR, train=True, download=True, transform=transform
)
train_loader = torch.utils.data.DataLoader(
train_set, batch_size=batch_size, shuffle=True, num_workers=2
)
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
for epoch in range(10): # loop over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
for i, data in enumerate(train_loader, 0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data
# zero the parameter gradients
Expand All @@ -37,17 +41,15 @@ def main():
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
if i % 2000 == 1999: # print every 2000 mini-batches
print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
if i % 2000 == 1999: # print every 2000 mini-batches
print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
running_loss = 0.0
scheduler.step()

print('Finished Training')

print("Finished Training")
torch.save(net.state_dict(), _MODEL_PATH)


if __name__ == "__main__":
main()
train()
16 changes: 0 additions & 16 deletions example/cifar10/code/data_slicer.py

This file was deleted.

81 changes: 0 additions & 81 deletions example/cifar10/code/ppl.py

This file was deleted.

48 changes: 0 additions & 48 deletions example/cifar10/code/test.py

This file was deleted.

Empty file removed example/cifar10/config/__init__.py
Empty file.
Empty file removed example/cifar10/config/config.py
Empty file.
Empty file.
12 changes: 5 additions & 7 deletions example/cifar10/dataset.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
name: cifar10

process: code.data_slicer:CIFAR10Slicer
pip_req: requirements.txt
#name: cifar10-train
#process: cifar.dataset:CIFAR10TrainBuildExecutor

name: cifar10-test
process: cifar.dataset:CIFAR10TestBuildExecutor
desc: CIFAR10 data and label test dataset
tag:
- bin

attr:
alignment_size: 4k
volume_size: 2M
volume_size: 10M
9 changes: 1 addition & 8 deletions example/cifar10/model.yaml
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
version: 1.0
name: cifar_net

model:
- models/cifar_net.pth

config:
- config/hyperparam.json

run:
ppl: code.ppl:CIFAR10Inference
ppl: cifar.ppl:CIFAR10Inference

desc: cifar10 by pytorch

tag:
- multi_classification
6 changes: 0 additions & 6 deletions example/cifar10/requirements.txt

This file was deleted.

Loading