Commit a681e9e

demo that explains dataloader nesting behavior
1 parent f8a51da commit a681e9e

2 files changed: +170 -0 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -356,6 +356,7 @@ Please note that the following notebooks below provide reference implementations
| Title | Dataset | Description | Notebooks |
| --- | --- | --- | --- |
| PyTorch DataLoader State and Nested Iterations | Toy | Explains how a DataLoader behaves when it is iterated inside a nested function call | [![PyTorch](https://img.shields.io/badge/Py-Torch-red)](pytorch_ipynb/mechanics/dataloader-nesting.ipynb) |
| Generating Validation Set Splits | TBD | TBD | [![PyTorch](https://img.shields.io/badge/Py-Torch-red)](pytorch_ipynb/mechanics/validation-splits.ipynb) |
| Dataloading with Pinned Memory | TBD | TBD | [![PyTorch](https://img.shields.io/badge/Py-Torch-red)](pytorch_ipynb/cnn/cnn-resnet34-cifar10-pinmem.ipynb) |
| Standardizing Images | TBD | TBD | [![PyTorch](https://img.shields.io/badge/Py-Torch-red)](pytorch_ipynb/cnn/cnn-standardized.ipynb) |
Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "561ef1d6-e8fc-431f-9c58-4c862b2813ec",
   "metadata": {},
   "source": [
    "# PyTorch DataLoader State and Nested Iterations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "37d57e04-facc-4543-9939-c724a57ce9c6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2.1.0'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "torch.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9fb9f618-f274-4f76-951f-40508cb85c1d",
   "metadata": {},
   "source": [
    "Iterating over a `DataLoader` inside a separate function does not affect the state of the iteration in the main training loop. In PyTorch, a `DataLoader` is typically an iterable rather than an iterator: each `for` loop calls `iter()` on it and receives a new, independent iterator. Every pass therefore starts from the beginning of the dataset and runs through it in a fresh sequence (with `shuffle=True`, the order is drawn anew for each pass).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1813f041-0366-4be3-890f-99c1a9c9d831",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "main loop: 1\n",
      "nested loop: 1\n",
      "nested loop: 2\n",
      "nested loop: 3\n",
      "nested loop: 4\n",
      "nested loop: 5\n",
      "main loop: 2\n",
      "nested loop: 1\n",
      "nested loop: 2\n",
      "nested loop: 3\n",
      "nested loop: 4\n",
      "nested loop: 5\n",
      "main loop: 3\n",
      "nested loop: 1\n",
      "nested loop: 2\n",
      "nested loop: 3\n",
      "nested loop: 4\n",
      "nested loop: 5\n",
      "main loop: 4\n",
      "nested loop: 1\n",
      "nested loop: 2\n",
      "nested loop: 3\n",
      "nested loop: 4\n",
      "nested loop: 5\n",
      "main loop: 5\n",
      "nested loop: 1\n",
      "nested loop: 2\n",
      "nested loop: 3\n",
      "nested loop: 4\n",
      "nested loop: 5\n",
      "main loop: 6\n",
      "nested loop: 1\n",
      "nested loop: 2\n",
      "nested loop: 3\n",
      "nested loop: 4\n",
      "nested loop: 5\n",
      "main loop: 7\n",
      "nested loop: 1\n",
      "nested loop: 2\n",
      "nested loop: 3\n",
      "nested loop: 4\n",
      "nested loop: 5\n",
      "main loop: 8\n",
      "nested loop: 1\n",
      "nested loop: 2\n",
      "nested loop: 3\n",
      "nested loop: 4\n",
      "nested loop: 5\n",
      "main loop: 9\n",
      "nested loop: 1\n",
      "nested loop: 2\n",
      "nested loop: 3\n",
      "nested loop: 4\n",
      "nested loop: 5\n",
      "main loop: 10\n",
      "nested loop: 1\n",
      "nested loop: 2\n",
      "nested loop: 3\n",
      "nested loop: 4\n",
      "nested loop: 5\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "# Custom Dataset class wrapping a list of consecutive integers\n",
    "class IntegerDataset(Dataset):\n",
    "    def __init__(self, start, end):\n",
    "        self.data = list(range(start, end + 1))\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.data[idx]\n",
    "\n",
    "# Create a Dataset for integers 1 to 10\n",
    "integer_dataset = IntegerDataset(1, 10)\n",
    "\n",
    "# Create a DataLoader\n",
    "integer_loader = DataLoader(integer_dataset, batch_size=1, shuffle=False)\n",
    "\n",
    "# Stand-in for a function that estimates the loss on a subset of training examples.\n",
    "# It starts its own pass over the loader that is passed in.\n",
    "def calc_loss(data_loader, iters):\n",
    "    for j in data_loader:\n",
    "        print(\"nested loop:\", j.item())\n",
    "        # Stop once the batch value reaches `iters` (values equal positions in this toy dataset)\n",
    "        if j.item() >= iters:\n",
    "            break\n",
    "\n",
    "# Example: Iterate over the DataLoader and start a nested pass at every step\n",
    "for i in integer_loader:\n",
    "    print(\"main loop:\", i.item())\n",
    "    calc_loss(integer_loader, iters=5)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
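
The behavior shown in the notebook follows from Python's iteration protocol: a `for` loop calls `iter()` on its target, and an iterable whose `__iter__` returns a new iterator gives every loop its own position. The sketch below reproduces this without PyTorch. `TinyLoader`, `peek`, and `run` are hypothetical names introduced only for illustration and are not part of `torch.utils.data`. `itertools.islice` is used here as one possible way to cap a nested pass by batch count rather than by batch value. The final lines show the contrasting case, where a single iterator object is shared and its state does carry over.

```python
from itertools import islice


class TinyLoader:
    """Hypothetical stand-in for a DataLoader: an iterable, not an iterator."""

    def __init__(self, data):
        self.data = list(data)

    def __iter__(self):
        # Every call hands out a brand-new iterator with its own position.
        return iter(self.data)


def peek(loader, max_batches):
    """Consume at most `max_batches` items from a fresh pass over `loader`."""
    return list(islice(loader, max_batches))


def run(loader, max_batches):
    """Nest a capped pass inside a full pass and record what each loop sees."""
    outer, inner = [], []
    for item in loader:  # implicit iter(loader): first iterator
        outer.append(item)
        inner.append(peek(loader, max_batches))  # second, independent iterator
    return outer, inner


loader = TinyLoader(range(1, 11))
outer, inner = run(loader, max_batches=5)
print(outer)     # the outer loop still visits all ten items
print(inner[0])  # each nested pass restarts from the first item

# Contrast: sharing ONE iterator object does carry state across consumers.
shared = iter(loader)
first_two = list(islice(shared, 2))
rest = list(shared)
print(first_two, rest)
```

The same distinction should apply to `integer_loader` in the notebook: passing the loader itself to a helper starts a new pass, whereas passing the result of `iter(integer_loader)` would share one position. One caveat worth checking against the installed PyTorch version: with `persistent_workers=True` and `num_workers > 0`, `DataLoader.__iter__` appears to reuse and reset a single iterator, so nested passes may not be independent in that configuration.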
