Tutorial: MNIST

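This tutorial trains a `torchvision` ResNet-18 classifier on MNIST with `torch_redstone`, from data loading to a complete training loop.
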
```python
# import torch, torchvision, and torch_redstone
import torch
import torchvision
import torch_redstone as rst
from torchvision.datasets import MNIST
from torchvision.transforms.functional import to_tensor


class MNISTTask(rst.Task):

    def data(self):
        # MNIST images are grayscale; convert to RGB since the model expects 3 channels
        # (see the shape check after this listing)
        transform = lambda x: to_tensor(x.convert('RGB'))
        train = MNIST('logs', train=True, transform=transform, download=True)
        test = MNIST('logs', train=False, transform=transform, download=True)
        return train, test  # each dataset yields (img, label) tuples

    def metrics(self):
        # categorical accuracy and cross-entropy loss metrics
        # call `.redstone()` to convert a supported `nn.Module` into a `rst.Metric`
        # the accuracy defaults to the name `Acc` and the loss to the name `Loss`
        # by default, predictions are read from `model_outputs.logits` and labels from `inputs.y`,
        # but our inputs are plain (img, label) tuples and the output is the raw prediction logits,
        # so we need a `DirectPredictionAdapter` (see below; a plain-PyTorch sketch of the
        # accuracy computation appears after the sample output)
        return [rst.CategoricalAcc().redstone(), torch.nn.CrossEntropyLoss().redstone()]


def main(epochs=10):
    # create our resnet18 model
    model = torchvision.models.resnet18(num_classes=10)
    # move the model to CUDA if available;
    # redstone automatically transfers data tensors to the same device as the model
    if torch.cuda.is_available():
        model = model.cuda()
    # run a training loop with the model on the MNIST task
    # the `DirectPredictionAdapter` transforms (x, y) inputs into ObjectProxy(x=x, y=y)
    # and wraps the model output into ObjectProxy(logits=output)
    # using ObjectProxy is recommended in redstone: with an adapter you can easily
    # adapt to multiple models and datasets
    # you can also add type annotations for your ObjectProxy interface for a better
    # development experience in modern IDEs (a sketch follows this listing)
    rst.DefaultLoop(
        model, MNISTTask(), optimizer='adam',
        adapter=rst.DirectPredictionAdapter(),
        batch_size=256
    ).run(epochs)


if __name__ == '__main__':
    main()
```
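
Why the `convert('RGB')` in the transform: `resnet18` expects 3-channel input, while MNIST yields single-channel grayscale PIL images. A quick standalone check of the resulting tensor shapes (the blank 28x28 image is just a stand-in for a digit):

```python
from PIL import Image
from torchvision.transforms.functional import to_tensor

img = Image.new('L', (28, 28))              # stand-in for one grayscale MNIST digit
print(to_tensor(img).shape)                 # torch.Size([1, 28, 28])
print(to_tensor(img.convert('RGB')).shape)  # torch.Size([3, 28, 28])
```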

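The comments in `main` suggest annotating your ObjectProxy interface. Here is a minimal sketch of one way to do that with `typing.Protocol`; the names `Inputs` and `Outputs` are hypothetical, and torch_redstone does not mandate any particular typing scheme:

```python
from typing import Protocol

import torch


class Inputs(Protocol):
    x: torch.Tensor       # a batch of images, as produced by the adapter
    y: torch.Tensor       # integer class labels


class Outputs(Protocol):
    logits: torch.Tensor  # raw class scores wrapped by the adapter
```

Annotating your metric or hook arguments with such protocols lets modern IDEs autocomplete `inputs.y` and `model_outputs.logits`.
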
Sample output:

```
T 00 Acc: 0.9634 Loss: 0.1191: 100%|████████████████████████████████████████████████████████| 235/235 [00:13<00:00, 17.22it/s]
V 00 Acc: 0.9846 Loss: 0.0477: 100%|██████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 34.43it/s] 
New best!
T 01 Acc: 0.9853 Loss: 0.0474: 100%|████████████████████████████████████████████████████████| 235/235 [00:10<00:00, 22.60it/s]
V 01 Acc: 0.9759 Loss: 0.0795: 100%|██████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 37.54it/s] 
T 02 Acc: 0.9886 Loss: 0.0369: 100%|████████████████████████████████████████████████████████| 235/235 [00:10<00:00, 22.49it/s]
V 02 Acc: 0.9805 Loss: 0.0601: 100%|██████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 37.07it/s] 
T 03 Acc: 0.9916 Loss: 0.0272: 100%|████████████████████████████████████████████████████████| 235/235 [00:10<00:00, 22.66it/s]
V 03 Acc: 0.9889 Loss: 0.0338: 100%|██████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 37.49it/s] 
New best!
T 04 Acc: 0.9915 Loss: 0.0269: 100%|████████████████████████████████████████████████████████| 235/235 [00:10<00:00, 22.22it/s]
V 04 Acc: 0.9902 Loss: 0.0307: 100%|██████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 37.52it/s] 
New best!
T 05 Acc: 0.9932 Loss: 0.0206: 100%|████████████████████████████████████████████████████████| 235/235 [00:10<00:00, 22.63it/s]
V 05 Acc: 0.9881 Loss: 0.0364: 100%|██████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 37.54it/s] 
T 06 Acc: 0.9947 Loss: 0.0176: 100%|████████████████████████████████████████████████████████| 235/235 [00:10<00:00, 22.56it/s]
V 06 Acc: 0.9925 Loss: 0.0257: 100%|██████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 37.41it/s] 
New best!
T 07 Acc: 0.9953 Loss: 0.0152: 100%|████████████████████████████████████████████████████████| 235/235 [00:10<00:00, 22.71it/s]
V 07 Acc: 0.9896 Loss: 0.0391: 100%|██████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 37.86it/s]
T 08 Acc: 0.9954 Loss: 0.0145: 100%|████████████████████████████████████████████████████████| 235/235 [00:10<00:00, 22.63it/s]
V 08 Acc: 0.9918 Loss: 0.0263: 100%|██████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 37.37it/s] 
T 09 Acc: 0.9953 Loss: 0.0150: 100%|████████████████████████████████████████████████████████| 235/235 [00:10<00:00, 22.59it/s] 
V 09 Acc: 0.9917 Loss: 0.0297: 100%|██████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 36.95it/s]
```
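
The `Acc` column is categorical accuracy over the logits. Presumably it matches the standard computation, sketched here in plain PyTorch for reference (the random tensors stand in for `model_outputs.logits` and `inputs.y`):

```python
import torch

logits = torch.randn(8, 10)     # stands in for model_outputs.logits
y = torch.randint(0, 10, (8,))  # stands in for inputs.y
acc = (logits.argmax(dim=-1) == y).float().mean()
print(acc.item())
```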
