-
Notifications
You must be signed in to change notification settings - Fork 0
Tutorial: MNIST
北海若 edited this page Feb 27, 2023 · 6 revisions
import torch
import torchvision
import torch_redstone as rst
from torchvision.datasets import MNIST
from torchvision.transforms.functional import to_tensor
# The imports above bring in torch, torch_redstone, and the torchvision MNIST dataset and tensor-conversion helpers.
class MNISTTask(rst.Task):
    """Redstone task for MNIST digit classification: supplies datasets and metrics."""

    @staticmethod
    def _to_rgb_tensor(img):
        """Convert a grayscale PIL image to a 3-channel tensor.

        The model requires 3 input channels, so the single-channel MNIST
        image is converted to RGB before ``to_tensor``.
        """
        # A named callable instead of a lambda assigned to a name (PEP 8 E731);
        # a named, importable function is also friendlier to pickling should the
        # datasets ever be sent to DataLoader worker processes.
        return to_tensor(img.convert('RGB'))

    def data(self):
        """Return the ``(train, test)`` MNIST datasets.

        Both datasets use ``logs`` as their root directory and are created with
        ``download=True``. Each yields ``(img, label)`` tuples, where ``img`` is
        the output of ``_to_rgb_tensor``.
        """
        train = MNIST('logs', train=True, transform=self._to_rgb_tensor, download=True)
        test = MNIST('logs', train=False, transform=self._to_rgb_tensor, download=True)
        return train, test

    def metrics(self):
        """Return the metrics tracked during training and validation.

        Two metrics are returned:

        - categorical accuracy, reported as ``Acc`` by default;
        - cross-entropy loss, reported as ``Loss`` by default.

        ``.redstone`` converts a supported ``nn.Module`` into an ``rst.Metric``.
        By default these metrics read logits from ``model_outputs.logits`` and
        targets from ``inputs.y``. This task's datasets instead yield
        ``(img, label)`` tuples, and the model outputs raw logits directly, so
        the task must be paired with an ``rst.DirectPredictionAdapter`` (as
        ``main`` does).
        """
        return [rst.CategoricalAcc().redstone(), torch.nn.CrossEntropyLoss().redstone()]
def main(epochs=10):
    """Train a ResNet-18 classifier on the MNIST task with redstone's default loop.

    Args:
        epochs: number of epochs forwarded to ``DefaultLoop.run``.
    """
    # ResNet-18 with a 10-way head (one class per digit), placed on CUDA when
    # available; redstone moves data tensors onto the model's device itself.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = torchvision.models.resnet18(num_classes=10).to(device)

    task = MNISTTask()
    # DirectPredictionAdapter turns (x, y) batches into ObjectProxy(x=x, y=y)
    # and wraps the raw model output as ObjectProxy(logits=output). Routing data
    # through ObjectProxy lets one adapter reconcile differing model/dataset
    # interfaces; the proxy interface can also be type-annotated for IDE support.
    adapter = rst.DirectPredictionAdapter()

    loop = rst.DefaultLoop(
        model,
        task,
        optimizer='adam',
        adapter=adapter,
        batch_size=256,
    )
    loop.run(epochs)
if __name__ == '__main__':
    # Run the tutorial with its default epoch count when executed as a script.
    main()
# Sample output:
T 00 Acc: 0.9634 Loss: 0.1191: 100%|████████████████████████████████████████████████████████| 235/235 [00:13<00:00, 17.22it/s]
V 00 Acc: 0.9846 Loss: 0.0477: 100%|██████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 34.43it/s]
New best!
T 01 Acc: 0.9853 Loss: 0.0474: 100%|████████████████████████████████████████████████████████| 235/235 [00:10<00:00, 22.60it/s]
V 01 Acc: 0.9759 Loss: 0.0795: 100%|██████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 37.54it/s]
T 02 Acc: 0.9886 Loss: 0.0369: 100%|████████████████████████████████████████████████████████| 235/235 [00:10<00:00, 22.49it/s]
V 02 Acc: 0.9805 Loss: 0.0601: 100%|██████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 37.07it/s]
T 03 Acc: 0.9916 Loss: 0.0272: 100%|████████████████████████████████████████████████████████| 235/235 [00:10<00:00, 22.66it/s]
V 03 Acc: 0.9889 Loss: 0.0338: 100%|██████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 37.49it/s]
New best!
T 04 Acc: 0.9915 Loss: 0.0269: 100%|████████████████████████████████████████████████████████| 235/235 [00:10<00:00, 22.22it/s]
V 04 Acc: 0.9902 Loss: 0.0307: 100%|██████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 37.52it/s]
New best!
T 05 Acc: 0.9932 Loss: 0.0206: 100%|████████████████████████████████████████████████████████| 235/235 [00:10<00:00, 22.63it/s]
V 05 Acc: 0.9881 Loss: 0.0364: 100%|██████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 37.54it/s]
T 06 Acc: 0.9947 Loss: 0.0176: 100%|████████████████████████████████████████████████████████| 235/235 [00:10<00:00, 22.56it/s]
V 06 Acc: 0.9925 Loss: 0.0257: 100%|██████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 37.41it/s]
New best!
T 07 Acc: 0.9953 Loss: 0.0152: 100%|████████████████████████████████████████████████████████| 235/235 [00:10<00:00, 22.71it/s]
V 07 Acc: 0.9896 Loss: 0.0391: 100%|██████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 37.86it/s]
T 08 Acc: 0.9954 Loss: 0.0145: 100%|████████████████████████████████████████████████████████| 235/235 [00:10<00:00, 22.63it/s]
V 08 Acc: 0.9918 Loss: 0.0263: 100%|██████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 37.37it/s]
T 09 Acc: 0.9953 Loss: 0.0150: 100%|████████████████████████████████████████████████████████| 235/235 [00:10<00:00, 22.59it/s]
V 09 Acc: 0.9917 Loss: 0.0297: 100%|██████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 36.95it/s]