Skip to content

Commit

Permalink
refactor cifar10 example
Browse files Browse the repository at this point in the history
  • Loading branch information
tianweidut committed Sep 13, 2022
1 parent 1b20a11 commit 71cb40f
Show file tree
Hide file tree
Showing 17 changed files with 86 additions and 107 deletions.
1 change: 1 addition & 0 deletions client/starwhale/core/dataset/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def astype(self) -> t.Dict[str, t.Any]:
"mime_type": self.mime_type,
"shape": self.shape,
"encoding": self.encoding,
"display_name": self.display_name,
}

def asdict(self, ignore_keys: t.Optional[t.List[str]] = None) -> t.Dict[str, t.Any]:
Expand Down
2 changes: 2 additions & 0 deletions example/cifar10/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
data/
models/
5 changes: 5 additions & 0 deletions example/cifar10/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# `train` names a command rather than a file to build, so declare it phony;
# otherwise a file or directory called `train` would make the target a no-op.
.PHONY: train
train:
	mkdir -p models data
	python3 cifar/train.py

File renamed without changes.
54 changes: 54 additions & 0 deletions example/cifar10/cifar/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import io
import pickle
import typing as t
from pathlib import Path

from PIL import Image as PILImage

from starwhale.api.dataset import Image, MIMEType, BuildExecutor

# Directory the CIFAR-10 batch files are read from: <parent of this package>/data/cifar-10-batches-py.
ROOT_DIR = Path(__file__).parent.parent.joinpath("data", "cifar-10-batches-py")
# The five training batch files, data_batch_1 through data_batch_5.
TRAIN_DATASET_PATHS = [
    ROOT_DIR.joinpath(f"data_batch_{batch_no}") for batch_no in range(1, 6)
]
# The single test batch file.
TEST_DATASET_PATHS = [ROOT_DIR.joinpath("test_batch")]


def parse_meta() -> t.Dict[str, t.Any]:
    """Load and return the unpickled contents of ``batches.meta`` under ``ROOT_DIR``."""
    meta_path = ROOT_DIR / "batches.meta"
    with meta_path.open("rb") as meta_file:
        meta = pickle.load(meta_file)
    return meta


# Loaded once at import time; ``_iter_item`` reads its "label_names" entry.
dataset_meta = parse_meta()


def _iter_item(paths: t.List[Path]) -> t.Generator[t.Tuple[t.Any, t.Dict], None, None]:
    """Yield one ``(Image, annotations)`` pair per record of each pickled batch file.

    Args:
        paths: Batch files to read, processed in the order given.

    Yields:
        A tuple of a PNG-encoded ``Image`` and a dict holding the record's
        ``label`` and its ``label_display_name`` looked up in ``dataset_meta``.
    """
    for batch_path in paths:
        with batch_path.open("rb") as batch_file:
            # encoding="bytes" leaves the batch's dict keys as bytes, hence the b"..." lookups.
            batch = pickle.load(batch_file, encoding="bytes")
            records = zip(batch[b"data"], batch[b"labels"], batch[b"filenames"])
            for pixels, label, raw_name in records:
                label_name = dataset_meta["label_names"][label]
                annotations = {"label": label, "label_display_name": label_name}

                # Rearrange the record from (3, 32, 32) to (32, 32, 3) before handing it to PIL.
                hwc_pixels = pixels.reshape(3, 32, 32).transpose(1, 2, 0)
                png_buffer = io.BytesIO()
                PILImage.fromarray(hwc_pixels).save(png_buffer, format="PNG")

                image = Image(
                    fp=png_buffer.getvalue(),
                    display_name=raw_name.decode(),
                    shape=hwc_pixels.shape,
                    mime_type=MIMEType.PNG,
                )
                yield image, annotations


class CIFAR10TrainBuildExecutor(BuildExecutor):
    """Build executor whose items come from the training batch files."""

    def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]:
        """Return a generator over the records in ``TRAIN_DATASET_PATHS``."""
        train_items = _iter_item(TRAIN_DATASET_PATHS)
        return train_items


class CIFAR10TestBuildExecutor(BuildExecutor):
    """Build executor whose items come from the test batch file."""

    def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]:
        """Return a generator over the records in ``TEST_DATASET_PATHS``."""
        test_items = _iter_item(TEST_DATASET_PATHS)
        return test_items
File renamed without changes.
File renamed without changes.
34 changes: 18 additions & 16 deletions example/cifar10/code/train.py → example/cifar10/cifar/train.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,37 @@
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from model import Net
from torch.optim.lr_scheduler import StepLR
import os

# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
# Anchor every path at the example root (the parent of the directory holding
# this script) so results do not depend on the current working directory.
_ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# The weights belong in <root>/models; the previous "../models" prefix escaped
# one level above the root, away from the directory `make train` creates.
_MODEL_PATH = os.path.join(_ROOT_DIR, "models", "cifar_net.pth")
_DATA_DIR = os.path.join(_ROOT_DIR, "data")


def main():
def train():
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
batch_size = 4
trainset = torchvision.datasets.CIFAR10(root='../data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
shuffle=True, num_workers=2)
train_set = torchvision.datasets.CIFAR10(
root=_DATA_DIR, train=True, download=True, transform=transform
)
train_loader = torch.utils.data.DataLoader(
train_set, batch_size=batch_size, shuffle=True, num_workers=2
)
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
for epoch in range(10): # loop over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
for i, data in enumerate(train_loader, 0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data
# zero the parameter gradients
Expand All @@ -37,17 +41,15 @@ def main():
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
if i % 2000 == 1999: # print every 2000 mini-batches
print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
if i % 2000 == 1999: # print every 2000 mini-batches
print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
running_loss = 0.0
scheduler.step()

print('Finished Training')

print("Finished Training")
torch.save(net.state_dict(), _MODEL_PATH)


if __name__ == "__main__":
main()
train()
16 changes: 0 additions & 16 deletions example/cifar10/code/data_slicer.py

This file was deleted.

48 changes: 0 additions & 48 deletions example/cifar10/code/test.py

This file was deleted.

Empty file removed example/cifar10/config/__init__.py
Empty file.
Empty file removed example/cifar10/config/config.py
Empty file.
Empty file.
12 changes: 5 additions & 7 deletions example/cifar10/dataset.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
name: cifar10

process: code.data_slicer:CIFAR10Slicer
pip_req: requirements.txt
#name: cifar10-train
#process: cifar.dataset:CIFAR10TrainBuildExecutor

name: cifar10-test
process: cifar.dataset:CIFAR10TestBuildExecutor
desc: CIFAR10 data and label test dataset
tag:
- bin

attr:
alignment_size: 4k
volume_size: 2M
volume_size: 10M
9 changes: 1 addition & 8 deletions example/cifar10/model.yaml
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
version: 1.0
name: cifar_net

model:
- models/cifar_net.pth

config:
- config/hyperparam.json

run:
ppl: code.ppl:CIFAR10Inference
ppl: cifar.ppl:CIFAR10Inference

desc: cifar10 by pytorch

tag:
- multi_classification
6 changes: 0 additions & 6 deletions example/cifar10/requirements.txt

This file was deleted.

6 changes: 0 additions & 6 deletions example/cifar10/runtime.yaml

This file was deleted.

0 comments on commit 71cb40f

Please sign in to comment.