diff --git a/README.md b/README.md index 5ebc78c..d7aa10a 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,7 @@ intersection2 = compute_psi(dataloader.dataloader2.dataset.get_ids(), dataloader # Order data dataloader.drop_non_intersecting(intersection1, intersection2) +dataloader.sort_by_ids() for (data, ids1), (labels, ids2) in dataloader: # Train a model diff --git a/examples/Simple Vertically Partitioned SplitNN.ipynb b/examples/Simple Vertically Partitioned SplitNN.ipynb index 90cbbb9..e8ffed5 100644 --- a/examples/Simple Vertically Partitioned SplitNN.ipynb +++ b/examples/Simple Vertically Partitioned SplitNN.ipynb @@ -134,7 +134,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 34, "metadata": {}, "outputs": [], "source": [ @@ -142,14 +142,269 @@ "data = add_ids(MNIST)(\".\", download=True, transform=ToTensor()) # add_ids adds unique IDs to data points\n", "\n", "# Batch data\n", - "dataloader = VerticalDataLoader(data, batch_size=128) # partition_dataset uses by default \"remove_data=True, keep_order=False\"\n", - "\n", + "dataloader = VerticalDataLoader(data, batch_size=128) # partition_dataset uses by default \"remove_data=True, keep_order=False\"" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "c1e2ef84-4b0f-4985-ae38-d16d10646588\n", + "c1e2ef84-4b0f-4985-ae38-d16d10646588\n", + "3\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAN2UlEQVR4nO3dXYhd9bnH8d/P10RrQE8mL9jB9FRvxHBi2eqBRomIEnORpDelXhQPBqcalQpeHMkJNHghUU5bXzgK02PsVHsMlVaiGBqjVE0RihPJMVE5vjFaQzQjEbR6EV+eczFLGXX2f4/7be34fD8w7L3Xs9ZeDyv5zdp7/ffsvyNCAL79jqq7AQD9QdiBJAg7kARhB5Ig7EASx/RzZ/Pnz48lS5b0c5dAKhMTE3r33Xc9U62jsNteKel2SUdL+u+I2Fxaf8mSJRofH+9klwAKGo1G01rbL+NtHy3pvyRdKulMSZfZPrPd5wPQW528Zz9X0qsR8XpEHJa0VdKa7rQFoNs6Cfupkv4+7fFb1bIvsT1ie9z2+OTkZAe7A9CJnl+Nj4jRiGhERGNoaKjXuwPQRCdh3y9peNrj71bLAAygTsL+rKQzbH/P9nGSfiLp4e60BaDb2h56i4hPbF8raYemht62RMQLXesMQFd1NM4eEdslbe9SLwB6iI/LAkkQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kERHs7hiMLz88stNa4cPHy5uu2vXrmJ9/fr1xbrtYr1Oa9eubVrbunVrcdvjjjuu2+3UrqOw256Q9IGkTyV9EhGNbjQFoPu6cWa/MCLe7cLzAOgh3rMDSXQa9pD0mO3dtkdmWsH2iO1x2+OTk5Md7g5AuzoN+/KI+IGkSyVdY/uCr64QEaMR0YiIxtDQUIe7A9CujsIeEfur24OSHpJ0bjeaAtB9bYfd9om2T/r8vqRLJO3rVmMAuquTq/ELJT1UjbMeI+l/IuLPXekqmX37yr8jx8bGivUHH3ywae2zzz4rbrt///5ivdU4+iCPs2/btq1p7aqrripue9tttxXr8+bNa6unOrUd9oh4XdK/dLEXAD3E0BuQBGEHkiDsQBKEHUiCsANJ8CeuA2DDhg3F+qOPPtqnTvJoNZx5xRVXFOvLly/vZjt9wZkdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgnH0AXHzxxcV6J+PsCxYsKNbXrVtXrLf6E9mjjmr/fPHMM88U60899VTbz42v48wOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kwzj4Arr766mK9NPVwK8cee2yxvmjRorafu1Pvv/9+sX7WWWcV662+Bruk1TE955xz2n7uQcWZHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYJx9ABxzTPmfYXh4uE+d9NeOHTuK9ffee69n+251TI8//vie7bsuLc/strfYPmh737Rlp9jeafuV6vbk3rYJoFOzeRn/W0krv7LsRklPRMQZkp6oHgMYYC3DHhFPSzr0lcVrJH0+f86YpPY/zwmgL9q9QLcwIg5U99+WtLDZirZHbI/bHp+cnGxzdwA61fHV+IgISVGoj0ZEIyIaQ0NDne4OQJvaDfs7thdLUnV7sHstAeiFdsP+sKTLq/uXS9rWnXYA9ErLcXbbD0haIWm+7bck/ULSZkl/sL1O0huSftzLJnHk2rp1a9Pa6OhocduPPvqo2+184aabburZcw+qlmGPiMualC7qci8AeoiPywJJEHYgCcIOJEHYgSQIO5AEf+KKovvvv79Y37x5c7H+2muvNa0dPny4rZ5ma9myZU1rrb5i+9uIMzuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME4+wCYmJgo1u+7775i/fHHH+9iN1+2a9euYt12z/Y9b968Yv2WW24p1letWtW0Nnfu3LZ6OpJxZgeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhn74O9e/cW66tXry7W33zzzW62c8S44IILivWRkZE+dfLtwJkdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgnP0IEBEp9/3II48U69u3by/WS3/PnlHLM7vtLbYP2t43bdkm2/tt76l+OKrAgJvNy/jfSlo5w/JfR8Sy6qf8KxZA7VqGPSKelnSoD70A6KFOLtBda/v56mX+yc1Wsj1ie9z2+OTkZAe7A9CJdsN+t6TvS1om6YCkXzZbMSJGI6IREY2hoaE2dwegU22FPSLeiYhPI+IzSb+RdG532wLQbW2F3fbiaQ9/JGlfs3UBDIaW4+y2H5C0QtJ8229J+oWkFbaXSQpJE5J+1sMej3hLly4t1p988slivdX3xq9cOdNgyZQ5c+YUt+21e+65p2ntjjvu6GMnaBn2iLhshsXN/wUBDCQ+LgskQdiBJAg7kARhB5Ig7EAS/InrADjttNOK9Y0bN/apk+7btGlT0xpDb/3FmR1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmCcHT21Y8eOultAhTM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBOPssffzxx01rrcaSL7roomJ97ty5bfU0CLZs2VKsX3/99X3qBK1wZgeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnr+zatatYv/nmm5vWHnvsseK2ExMTxfrw8HCx3kuHDh0q1rdv316s33DDDcX6hx9++I17+twJJ5xQrB/Jn0+oQ8szu+1h23+x/aLtF2z/vFp+iu2dtl+pbk/ufbsA2jWbl/GfSLohIs6U9K+SrrF9pqQbJT0REWdIeqJ6DGBAtQx7RByIiOeq+x9IeknSqZLWSBqrVhuTtLZXTQLo3De6QGd7iaSzJf1N0sKIOFCV3pa0sMk2I7bHbY9PTk520CqATsw67La/I+mPkq6PiPen1yIiJMVM20XEaEQ0IqIxNDTUUbMA2jersNs+VlNB/31E/Kla/I7txVV9saSDvWkRQDe0HHqzbUn3SHopIn41rfSwpMslba5ut/Wkwz657rrrivW9e/e2/dy33nprsX7SSSe1/dyd2rlzZ7G+e/fuYn3qv0d7VqxYUayvX7++WL/wwgvb3ndGsxln/6Gkn0raa3tPtWyDpkL+B9vrJL0h6ce9aRFAN7QMe0T8VVKzX9/lb2UAMDD4uCyQBGEHkiDsQBKEHUiCsANJ8CeufXDXXXfV3ULPLFiwoFhfvXp109rtt99e3HbOnDlt9YSZcWYHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQYZ6/ce++9xfqdd97ZtDY2Nta0VrfTTz+9WG/1dc3nn39+sX7llVcW60uXLi3W0T+c2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcbZK2effXaxfvfddzetnXfeecVtN27cWKy3mjZ57dryNHqXXHJJ09qaNWuK2y5atKhYx7cHZ3YgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSMIRUV7BHpb0O0kLJYWk0Yi43fYmSVdKmqxW3RAR20vP1Wg0Ynx8vOOmAcys0WhofHx8xlmXZ/Ohmk8k3RARz9k+SdJu2zur2q8j4j+71SiA3pnN/OwHJB2o7n9g+yVJp/a6MQDd9Y3es9teIulsSX+rFl1r+3nbW2yf3GSbEdvjtscnJydnWgVAH8w67La/I+mPkq6PiPcl3S3p+5KWaerM/8uZtouI0YhoRERjaGioCy0DaMeswm77WE0F/fcR8SdJioh3IuLTiPhM0m8kndu7NgF0qmXYbVvSPZJeiohfTVu+eNpqP5K0r/vtAeiW2VyN/6Gkn0raa3tPtWyDpMtsL9PUcNyEpJ/1pEMAXTGbq/F/lTTTuF1xTB3AYOETdEAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSRafpV0V3dmT0p6Y9qi+ZLe7VsD38yg9jaofUn01q5u9nZaRMz4/W99DfvXdm6PR0SjtgYKBrW3Qe1Lord29as3XsYDSRB2IIm6wz5a8/5LBrW3Qe1Lord29aW3Wt+zA+ifus/sAPqEsANJ1BJ22ytt/5/tV23fWEcPzdiesL3X9h7btc4vXc2hd9D2vmnLTrG90/Yr1e2Mc+zV1Nsm2/urY7fH9qqaehu2/RfbL9p+wfbPq+W1HrtCX305bn1/z277aEkvS7pY0luSnpV0WUS82NdGmrA9IakREbV/AMP2BZL+Iel3EXFWtexWSYciYnP1i/LkiPj3Aeltk6R/1D2NdzVb0eLp04xLWivp31TjsSv09WP14bjVcWY/V9KrEfF6RByWtFXSmhr6GHgR8bSkQ19ZvEbSWHV/TFP/WfquSW8DISIORMRz1f0PJH0+zXitx67QV1/UEfZTJf192uO3NFjzvYekx2zvtj1SdzMzWBgRB6r7b0taWGczM2g5jXc/fWWa8YE5du1Mf94pLtB93fKI+IGkSyVdU71cHUgx9R5skMZOZzWNd7/MMM34F+o8du1Of96pOsK+X9LwtMffrZYNhIjYX90elPSQBm8q6nc+n0G3uj1Ycz9fGKRpvGeaZlwDcOzqnP68jrA/K+kM29+zfZykn0h6uIY+vsb2idWFE9k+UdIlGrypqB+WdHl1/3JJ22rs5UsGZRrvZtOMq+ZjV/v05xHR9x9JqzR1Rf41Sf9RRw9N+vpnSf9b/bxQd2+SHtDUy7qPNXVtY52kf5L0hKRXJD0u6ZQB6u0+SXslPa+pYC2uqbflmnqJ/rykPdXPqrqPXaGvvhw3Pi4LJMEFOiAJwg4kQdiBJAg7kARhB5Ig7EAShB1I4v8B9Tkab9GiCOcAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "print(dataloader.dataloader1.dataset.ids[6])\n", + "print(dataloader.dataloader2.dataset.ids[6])\n", + "print(dataloader.dataloader2.dataset[6][0])\n", + "plt.imshow(dataloader.dataloader1.dataset.data[6].numpy().squeeze(), cmap='gray_r')" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------------\n", + "Id 1 of the DataLoader 1: a6e0ebf1-1687-41d7-a64d-778ed678f532\n", + "Id 1 of the DataLoader 2: a6e0ebf1-1687-41d7-a64d-778ed678f532\n", + "----------------------------------------\n", + "Id 2 of the DataLoader 1: 643b7ca1-d06c-43be-922f-42d3559f2ab2\n", + "Id 2 of the DataLoader 2: 3daa20a8-a5a3-48eb-b73c-d8b31c9cf5ee\n", + "----------------------------------------\n", + "Id 3 of the DataLoader 1: 98330029-1072-47c3-ab85-6c48c0c36ede\n", + "Id 3 of the DataLoader 2: 643b7ca1-d06c-43be-922f-42d3559f2ab2\n", + "----------------------------------------\n", + "Id 4 of the DataLoader 1: 30eaeace-a4a4-4c6c-bcba-a358cf9d2646\n", + "Id 4 of the DataLoader 2: 98330029-1072-47c3-ab85-6c48c0c36ede\n", + "----------------------------------------\n", + "Id 5 of the DataLoader 1: 8fc4db7d-97c9-4432-9e2d-332e2e379773\n", + "Id 5 of the DataLoader 2: 30eaeace-a4a4-4c6c-bcba-a358cf9d2646\n", + "----------------------------------------\n", + "Id 6 of the DataLoader 1: c1e2ef84-4b0f-4985-ae38-d16d10646588\n", + "Id 6 of the DataLoader 2: c1e2ef84-4b0f-4985-ae38-d16d10646588\n", + "----------------------------------------\n", + "Id 7 of the DataLoader 1: f4528b46-9648-4e21-be9f-74aca8a815ca\n", + "Id 7 of the DataLoader 2: f4528b46-9648-4e21-be9f-74aca8a815ca\n", + "----------------------------------------\n", + "Id 8 of the DataLoader 1: 63be02aa-908b-4437-aba9-e637eeab62c8\n", + "Id 8 of the DataLoader 2: 63be02aa-908b-4437-aba9-e637eeab62c8\n", + "----------------------------------------\n", + "Id 9 of the DataLoader 1: 089de33a-437c-4d48-ac96-e29edf4c4a5c\n", + "Id 9 of the DataLoader 2: 089de33a-437c-4d48-ac96-e29edf4c4a5c\n", + "----------------------------------------\n" + ] + } + ], + "source": [ + "print(\"----------------------------------------\")\n", + "for i in range(1,10):\n", + " print(\"Id\", i, \"of the DataLoader 1: \", dataloader.dataloader1.dataset.ids[i])\n", + " print(\"Id\", i, \"of the DataLoader 2: \", dataloader.dataloader2.dataset.ids[i])\n", + " print(\"----------------------------------------\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Check first the datasets \n", + "In MNIST, we have the images and the labels." + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 4 1 9 2 3 1 4 3 5 " + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAAAqCAYAAAAQ2Ih6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAZd0lEQVR4nO2da1CTZ9rHf+GQcAYRo4RwErGgUl0LtoDSlqUNVqwiPdjVOsv2ON3u9LD7ZWe37exup7PvbNtpd5bWtmorrseKVosVpRyqNsqiHFROVgPIMUAgCYRDDuT94JAprVareYLvu89vhi95JNdf8jz/+7rv67rvSOx2OyIiIiIirsFtqgWIiIiI/Dchmq6IiIiICxFNV0RERMSFiKYrIiIi4kJE0xURERFxIR7XuT4VrQ2Sq7wm6piMqGMyoo4fc7toEXX8ADHTFREREXEhoumKiIiIuBDRdEVERERcyPXWdG+a3t5eSkpKyM/Pp6WlhSeeeIK1a9cSGxsrVEiR25y9e/fy/vvv09nZSW5uLr/97W+ZNm3aVMsS+Rn09vbym9/8Bl9fX3bt2uXU9+7s7EStVrNv3z70ej0JCQmEhoYyZ84cMjMz8fAQzK5ciiD/i56eHj7++GPy8vLQ6/XYbDbeffddTp8+zYEDB4QIeUM0Nzfzz3/+k507dxIaGsrmzZtZvHjxlOmZCiwWCyaTidbWVqqqqvDz82P16tV4enoKGlen0/H1119TU1PD6OgodXV1aLVal5uuVqvFaDSi0WjYunUrJSUlSCST6x2BgYGsWbOGv/3tb4I+6OPj47S0tLBnzx5mzJjBU089JVgsZ9DX18ebb77JqVOnWLt2rVPfu6SkhLfffpuKigpMJhN2u52ysjLc3d257777iIqKYsGCBU6NOVU4/Y7q7Ozk448/ZvPmzfT09ODp6Ymvry9wxfTOnz9PXFycy0ettrY2tm7dSn5+Pnq9Hm9vb6xWq0s1TCV6vZ7S0lIKCws5f/48bW1tmM1mFAoFOp2O559/XtD4vr6+xMTEEBISQmtrK1ar1aV//7a2No4ePcrBgwdpa2tjYGAArVbL6Ojoj/6tTqfj3//+Nz4+Prz66quO+9fZjIyMsHHjRrZt20Zubi5DQ0P4+fkJEutW6e7u5p133mHbtm14eXmRmprq1PdftGgR06ZNY2RkBLPZDFxJEAC++eYbXn/9dd544w0WLlzo1LhTgdOcz2Kx0NbWxqZNm9ixYwfd3d3Y7XZ8fHxYunQpMTExbNy4kQ0bNvDqq6+yfv16Z4W+IQYHB2lpacFoNLosZnt7O+Xl5Zw4cYK6ujr6+vqIj4/n0Ucfxc3NDa1Wi0ql4o477hBMQ09PD0ePHmXv3r1UV1ej1+uRSCQEBgbi7u5Oc3MzBw8eJCsrC6VSKZgOmUxGaGgogYGBgsX4KbZv305+fj6XL1/GarVis9mw2WwAP8p0x8fH0Wq1bN++nYcffliwB91utzM4OIhWq6W7uxuj0Xjbmu7g4CCnTp1idHSU1NRUsrKynPr+wcHB5OTkoNFoqKysxMfHBzc3N4xGI0NDQ9TV1XHmzJnbxnQtFgs9PT0MDAzw1VdfsX//fiQSCenp6bz55ps/+btOM929e/eydetWqqurGRgYcNzQJpOJ3t5eUlJSiImJobm5mdraWpeabm9vL8XFxZSWluLl5UVSUhKvvfYac+fOFSxmbW0tH330ESUlJfT09ODv709wcDA1NTVUVVUB4OXlhdFo5M9//rPT44+MjFBQUEBBQQG1tbX09vYyPDxMQEAAK1eu5Mknn6SiooI33niD9vZ2dDqdoKY7NjZGV1cXBoNBsBg/RVhYGHa7nZGREXx9fUlMTCQ6OhoANzc39Ho9Z8+e5eLFiwDYbDaMRqMj6xISu93uGAimisuXL1NYWEhoaCjLly/Hy8vLcW1gYIDS0lJaWlq44447ePnll50+OEgkEh588EHUajVNTU14enoikUgcSdKMGTNITk52asybQa/XU1FRQWlpKeXl5RiNRvR6PTqdDjc3N8bHx11juh0dHRw5coSTJ08yODjIrFmzuOuuuzCZTJw8eZJp06axePFiDAYDH3zwwVWndELR19fHvn372Lx5M1qtloiICNatW0dqaqog65jj4+NcvnyZDz/8kAMHDiCTycjOziYrK4vw8HAqKirIy8ujqakJhUJBUFCQ0zUAVFdXs23bNtRqNaOjo4yPjxMQEEB2djYvvfQSERERtLS0AFdGbaEf+AnTnXiIdDodLS0thIeHuyT7ValUKBQKDAYDUqmUsLAwgoODHddHR0cpLi7mtddew2g0OmZoc+bMEVwbXLlvpnK564svvuDjjz/mwQcfZMmSJYSFhTmutbe3s3XrVmw2Gy+88IJg5ufv749KpaK5uRm1Wk1/f7/jmqen55TOAmpra9m7dy9NTU2cO3cOnU6HXq/H39+f9PR04uLiSE9Pv6F7+ZZNt7Ozk3feeYeSkhKGh4eJjIxk/fr1pKSkUF9fT2pqKiqVioSEBCQSiaOg1tLSQlRU1K2Gvy7fffcdRUVFNDU14ePjQ3JyMg8//DBSqVSQeGfOnCEvL4+ioiJ8fHzIzc3lkUceITIykv7+fvr7++nu7kYqlRIVFUVGRoYgOo4fP05jY6Mju42MjESlUrF27VrmzZvHwMAAfX19gsS+Gr6+viQkJKBUKhkYGKC2tpZdu3YREhLCPffcI3h8uVzOsmXLGB8fRyKR4OHhgbu7u+N6T08Po6OjjsFHKpUyZ84cwQbFH9Lf349Wq3Vk365Gp9PR09ODXq9nZGTE8frw8DBNTU00NDTg5+dHaGgoPj4+gum45557mDZtGlu2bGHXrl3o9XrgivEXFRXxzDPPCBb7apjNZs6cOcOmTZsoLi7Gzc0NhULBE088QVhYGIGBgcyfP5/AwEBmzZqFm9v1u3BvyXT7+vooLCykqKgIvV5PSEgIS5cuRaVSsXDhQubNm4ebmxuhoaF4eHgQHByMxWLh0qVLFBQU8Pvf//5Wwl8Xs9lMdXU1arUaiURCbGws2dnZzJgxQ5B4R48eZePGjZSVlREdHc369evJzs4mMjISNzc3vv32Ww4dOoTRaCQiIoL169cTExMjiJbU1FS6urrw9PQkNjaWqKgoYmNjUSqVeHp6Mjw87FLT9fDwQKVSUV5ezrlz59Dr9bS2trp0jf1aA61Wq3UsP03MwgIDA8nJyfnReq8zcXNzQyqV4unp6SjsTQVqtZqqqirc3d1RKpWOgWZ8fByNRsNXX32FzWYjLi5O8DVVrVZLW1sbOp3uRzPiGzE0Z3Pu3Dm2bNlCdXU1WVlZZGRkMGvWLKKjo/H398fDw2PSUsyNcNOma7PZ2L9/P/n5+XR1dbF8+XLS0tKIi4tjzpw5+Pv74+/vf9XfHR4e5vTp0zcb+oYwmUyUlpZSUFBAb28vkZGRZGVlkZaWJki80tJS3nvvPdRqNXK5nNzcXNasWUNoaChwpfpbWVlJY2MjcrmcNWvWCNqqlZSUhFwuRyaTERwcjI+Pz6TMzmAw0NbWJkjsa6FQKAgODmbi20qm+ltL2tvbOX78OOXl5dTU1KDRaLDZbI5ZyLx58wSNL5PJCA8PRy6XYzabGRsbEzTe1RgaGuLLL7+kurqaJUuW8Mtf/tJhui0tLezYsYPS0lJiY2P505/+JNi6/8DAAAcOHKCoqIjm5mZaWlommW5ISAh33XWXILF/iNVqpba2lhMnTjAyMkJkZCTz589HpVIxd+7cSc/RzXDTptvZ2UlBQQFVVVVkZmby3HPPsXjxYmQy2XXbwWw2G8PDwzcb+oZoa2ujsLCQ06dPExgYSFpaGmvWrGH69OlOjzU2NuZYP5XL5Tz//PNkZ2ejUCgYHh5GrVbz1Vdf8c033yCTycjIyGDDhg3MmjXL6VomkMlkkwqFGo2G1tZWx/S5sbGR5uZmvL29iY6OnrSG5yqEzCJ/yODgIPX19TQ0NDA0NATAhQsXOH78OBqNhuHhYSQSCWFhYahUKtLT0wVbgprA3d2doKAgwVrSrodOp+Pzzz/n8OHDBAcHs2rVKhYvXoyHhwddXV3s37+fXbt2YTabWb16Nffdd59gWtrb29m/fz9Hjx69as1nZGSErq4uFi1aJJiGCaqrq9m0aRP19fVkZ2eTk5NDYGAgISEhTsm2b9p0v/jiC86ePUtQUBBPPPEEycnJN7TW44osp6uri0OHDnH8+HHc3Ny47777WL9+vWCtWSMjI1RVVWEymVi9ejUrV65Er9dz6tQpLly4wLFjx6ioqMBoNHLnnXeSlZXF/PnzBdHyfWw2G3q9ntraWgoLC6mrq3P0Pg4NDdHW1kZMTAy/+tWvkMvlguuBK0Y78eMq+vv7+fzzzzl69ChNTU2YTCbgSiXaYDA4tAQEBJCSksJLL71EdHT0LWc0Pwer1eqSTgm48uxduHCB3bt3s2vXLjQaDQkJCQwNDaHT6ZBKpZw7d47CwkL0ej2ZmZk8+uijgmqaNm0a8+fP5+zZs44C7/fR6XTU1tayfPlyQXUAHDx4kAMHDpCYmEhiYiIxMTFOXdq4KdO9fPkyBQUFDAwM8MADDxAfH/+zFtfd3d0FyTjhigEePHiQbdu2cenSJSIiIrj33ntZtmyZYFP57w8gPT097NixA41Gw9mzZ9FqtYyMjDA6OoqHhwezZ88mJSVF0M0hdrud/v5+ysrKqK2t5fTp05w/fx6bzYbdbqevrw+bzYZMJiM+Pp7Zs2djt9tdaoSupL+/ny+//JLy8vJJM6wfDvwTRbbg4GCXZ5+9vb10dHS4JNbFixfZuHEjO3fuRKvV4ufnR0dHBzt27KCjo4Po6GgqKipobGxk3rx5PPPMM8TFxQmqSalUsmHDBmQyGTU1NY6llo6ODurr67FarZMKfEIysWtxcHDQ0ebozJ2TN/XkNzY20traikQiYdmyZTeUJRkMBurq6pBKpcyePZuHHnroZkJfl/Pnz1NQUMDZs2cBCA8PJyEh4Wcvdv8cvL29UalU9Pb2sm3bNnx8fAgNDWX27NksXrwYrVbLmTNnAEhISCA8PFwwLePj43R1dbFv3z62bNnC6Ogo0dHRrFixAqVSSVNTE4WFhRiNxknbUBUKBRERES4pVkyYnclkclSnhSQgIIC0tDTMZjPu7u6O9e0JHQaDgaamJtra2qivr0ej0aBQKATXBVcKdn5+frS2ttLZ2Sl4vMuXL7Np0yZ27tyJwWAgOTmZuXPn0tTURH19PRcuXHD0j1utViQSCWazGZPJJPhAFBcXx9NPP01PTw8WiwW73c7Jkyf58MMP6evro7+/n5GREby9vZ0ee2BgAB8fH8fyX29vLzU1NWzatAlPT0+WL1/utLg3Zbpmsxmr1Yq/vz933XXXpH7Hq6HX6ykrK+Ozzz5jxowZZGdn8/DDD9+U4J9iaGiII0eO8N133+Hu7o6/vz/33HMPiYmJTo/1fby8vHjxxRfx8vKipaWFgIAA4uLiSE5OJjg4mPz8fE6fPo1CoRB0WcFut1NfX8/Bgwf56KOPGBkZ4emnn2blypXMmTOH7u5uPvnkEzw8PAgKCkIulzMwMMDu3bvx8/NzzFomGtOFaJf6fnbZ2dlJTU0NWVlZgrYhyeVynnvuOe6++24CAgKYPn36pMGlvb2dzz77jE8++cTl2b5SqUShUHD+/HksFgtjY2PIZDLB4h04cIA9e/Y41q7Xrl3L3LlzOXbsmOM+nRgIJRIJ58+fZ9euXSQlJbkk+w8LC3PUF+x2OwaDgcDAQDo6Oujs7MRgMDjVdMfGxqioqKC2tpYlS5aQkJDAqlWrUCgUvPXWW3z99ddERESwZMkSpxURb2mOO3PmTKZPn37NqbLFYkGr1fLtt9+ydetWurq6eOyxx3j00UcFubGOHTtGWVkZfX19eHt7k5qaesMNy7dKREQEf/3rX3/0+sWLF9FoNGi1Wu68804iIyMF03DhwgXy8vIoKCggKCgIlUpFTk4OsbGxaDQaNm/eTEFBAX5+ftx///0kJydTV1fH8ePH2b59OydOnCAzMxNfX1/8/PzIzc11usbg4GCCgoIwGo309vZy6tQpGhoaBK9MBwYGcu+99171WkBAAPHx8YLGvxa+vr74+PgwPj6OyWTCaDQK1tII0NTURGpqKitXrkSlUjkGVqlUSkNDA42Njfj5+REVFYW3tzdyuZwHH3zwuomVEOh0Or799lsqKysJCQkhPDzc6VlucXExb731FjabDX9/f2JjYwkODna0zslkMry9vZ06GN+S6SqVymtO2yf2SxcVFXH48GHMZjMvvfSSoAerfPrpp6jVaqxWKwsWLGDdunWCVlxvBKlUilQqRSKREBoaKtjD3dnZyfvvv8/u3btRKBQ8/fTTpKen4+XlRXl5OXv27KGyspLIyEhWrlzJunXrUCqV9Pf3U1tbS0lJCUVFRbzzzjt4eHiwbNkyQUx34cKFxMfHU1FRAVwpelZVVQliuhOZkq+v7zXX800mE1VVVajVaqfHvxH8/f0JCgrCw8MDg8GAVqsV1HQfeugh5s+fT1hYmCNZmjhr4uLFiwwNDbF69WpefvllQkJCkEqlghiu1WpFo9Hg5+fHjBkzfvT52Gw2Tp48yZEjR4ArGXBmZqbTE6gPPviAhoYGVqxYQUBAAOfOnUMqlVJeXk5dXR0LFixg6dKlTv1Mbtp07XY7lZWVXLhwgVmzZuHl5YXFYsFgMDA4OEh1dTU7duxArVYTHh7Os88+K6jhdnd3o9PpGBsbQyKRkJSU5Gh/mUrkcjkhISGCx8nPz2fnzp1YLBZWrFhBfHw8NTU1lJSUcOzYMYaGhsjMzCQ3N5fExEQCAgKAK3vaMzIySE5O5u6776a8vByZTMbvfvc7QXRGREQQGRnJmTNnsFgsjI6O0t/fj8VicWqh02Qy0dnZyTfffMMDDzyAUqmc1I0wPj7OyMgIp06dIi8vj8OHDwNXptSu7FqIjo4mMTGRo0eP0t3dzaVLlwQ9wvBqtZSenh5KS0tpaGhAqVRy//33Cz7zaGxs5I033mDevHls2LCB8PBwpFIpVquV0dFROjs7KSkpobq6Gi8vL2JiYgTRlJubi06no7i4mIKCAmw2G9OnT8dutzNz5kwef/xxp7cP3rQjSSQSDAYDeXl5dHd3ExsbS3t7O2VlZVRWVtLR0YFUKiUxMZHMzExUKpXTRP8Qs9lMXl4eDQ0NwJUp25w5cxwbE6aSixcv0traKnicnTt3YjKZcHd3Z9++fezduxej0YhEIsHX15fs7GyeeeYZEhMTrzpV8vX1ZfXq1axevVpQnYsWLSI1NZXy8nK0Wi0tLS0cOXKEJ5980qnFq9LSUj788EOOHTvG5s2bWbp06aQHx2AwUFtby/bt2zl8+DDj4+MEBwcTERFxzU09QqFUKpk9ezbx8fGsWLHCpbFtNhvFxcXs2LEDvV7PK6+8wmOPPSZ43N27d6NWq/nPf/6DTCZj8eLFREdH09fXR2NjI2VlZZSUlGCxWPjFL37B2rVrBakxTJz4V1NTw5kzZxwF5vDwcB555BEyMjKcHvemTFcmk+Hp6Yndbqe4uJiKigoCAwPR6/WYTCbH1tOJI+CE6lSYQKPRoFar0ev1yGQyMjMzWbp0qaDFmRvFZDI5+kKFJC4ujv7+fkevpa+vLykpKSQnJ5OWlsb8+fMJCAi4LdrCFAoFSqWSnp4eQJhNEu+++y6nTp3CbDazfft2ysvLJ2Wwly5dorq6mt7eXkdHzcqVK8nJyZmyw7Injpt05ezs0qVLHDp0iI6ODtLT00lLS3PJwTKPP/44TU1NqNVq/vKXvxAUFER0dLSjdW6in1ypVLJq1SpWrVolWMtnTk4OOTk5grz31bipT3fhwoXce++9FBUVodPpMBgMGAwGZDIZs2bNIikpiWeffZbMzExn670qJpOJ/v5+xsbGmDlzJo899hgpKSkuiX09oqKiHC1idrud8fFxQeJ8+umnlJaWUlNTg4+PD5mZmURHR+Pl5XVbGO33WbhwIUlJSdTW1gp2spa/vz9eXl5YrVYOHTo06Zqbm5vjRy6Xs3DhQtatW8fatWsF/waNq2G1WhkeHnYc8u+q7a4A7733HkeOHGHRokU8++yzLjl8CGDBggWkpaWRkZHBBx98wKVLlzh37hxms9nxjAQFBfHkk0/y+uuvu0STq7gp05XL5bz99tukpKTwr3/9y7GH//777+fXv/41S5cudenXsEw8QLebucCVDo958+Yxc+ZM9Ho9HR0dgmScPj4+ZGVlOf1waSGYMWMGCQkJxMTEOJrznd0f/Pe//538/Hzy8/Mdm0EmThYLDQ1FLpcTHR1NZmYmDzzwwJQuRX333XdcvHiRzMxMl/UHT5CRkcHZs2dJSkpi9uzZLo394osvYrFYmDZtGl1dXbS2tlJaWkpzczNeXl48/vjjvPDCCy7V5BLsdvtP/UwFP1uHTqezv/jii3alUmmPioqy79u3b0p0XIuamhr7unXr7Eql0v7KK6/YNRrNlOi4Rf7P6TCbzfY9e/bYk5KS7EFBQfb09HT7W2+9Za+srLSbTCaX6bgeBQUF9hUrVtj/+Mc/2vv6+pyh47b/bP6LdSCx//QZCFNxDNTVUsD/0zrMZjMFBQX8z//8DwaDgaeeeoo//OEPN7pL7v/d3+MWEXVM5lpTpttFi6jjhy+KpntNnKrDbDZz4sQJ/vGPf2A2m9myZcuNbpT4f/n3uAVEHZMRTffH3M46RNP9CUQdkxF1TOZ21gG3jxZRxw9fvI7pioiIiIg4Edd//4WIiIjIfzGi6YqIiIi4ENF0RURERFyIaLoiIiIiLkQ0XREREREXIpquiIiIiAv5XwO2vnGmvcdmAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# We need matplotlib library to plot the dataset\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Plot the first 10 entries of the labels and the dataset\n", + "figure = plt.figure()\n", + "num_of_entries = 10\n", + "for index in range(1, num_of_entries + 1):\n", + " plt.subplot(6, 10, index)\n", + " plt.axis('off')\n", + " plt.imshow(dataloader.dataloader1.dataset.data[index].numpy().squeeze(), cmap='gray_r')\n", + " print(dataloader.dataloader2.dataset[index][0], end=\" \")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## So, lets implement PSI and Order the datasets accordingly" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ "# Compute private set intersections\n", "intersection1 = compute_psi(dataloader.dataloader1.dataset.get_ids(), dataloader.dataloader2.dataset.get_ids())\n", "intersection2 = compute_psi(dataloader.dataloader2.dataset.get_ids(), dataloader.dataloader1.dataset.get_ids())\n", "\n", "# Order data\n", - "dataloader.drop_non_intersecting(intersection1, intersection2)" + "dataloader.drop_non_intersecting(intersection1, intersection2)\n", + "dataloader.sort_by_ids()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Check again if the datasets are ordered" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 4 0 0 6 4 2 9 1 0 " + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAAAqCAYAAAAQ2Ih6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAbzklEQVR4nO2de1BUV7aHv27oBgRsFASVtwZFUTSiAoIgikZRMTAadSbGyUSjMUlFpyZxqpKKlbrJpJKqydOok5go6iSKLzRoEFQagoKID0ABUZCXCMr72U13c+4fFn19C9KH8dacr4o/oHf3+nHOPmvvvdbau2WCICAhISEh0TfI/9MCJCQkJP6bkJyuhISERB8iOV0JCQmJPkRyuhISEhJ9iOR0JSQkJPoQ8ye8/p8obZA95G+SjnuRdNyLpONBnhUtko77kGa6EhISEn2I5HQlJCQk+hDJ6Ur8R6mrq2Pr1q14enoSEBDA1atX/9OSJCRERVSnq9friYuLw8/Pj1WrVtHc3CymuWeWwsJC1qxZw/jx49m4cSPnz5/ns88+Izc3t091pKWlER0djUqlQqVSYW9vz7x587h27Vqf6gAQBIGsrCzeeust1q1bR0NDA6NHj2bIkCF9rqWvEQSBtLQ0Xn75ZVxdXY33Q6VS4ebmRlRUFNu3b/+vfV7upr6+nk2bNjF8+HBefPFFioqK+sy2wWCguLiY9evXM2rUKPz9/YmIiCA+Ph6DwfD0HywIwuN+ekVHR4fwyy+/CE5OTsKwYcOEzZs3d+dtJtfxlJhEx82bN4X33ntPUKlUgpOTk7B06VLhnXfeEWxtbYXg4GChoqKiT3Rs3bpV8PX1FSwtLQXuJBUEmUwmWFtbC7NmzRIKCwv7REcX+fn5wurVqwUnJydhxIgRQmBgoJCRkdGdt3ZLR25urrBu3TqhqqqqW3quXLkifPPNN8K+ffu6+R883fWor68Xvv76a+O9sLCwEKytrQUbGxvBxsZGsLa2FiwtLYX+/fsLoaGhQnV19dPo+H/9zNxNXl6esGLFCsHc3Fzw9vYW9u/f3yc69Hq9cPDgQWH69OmCjY2NYG5uLnh4eAjh4eHC5MmThZiYmKfVIe5M19zcHEdHR5ydnampqeHUqVNimruHyspKSkpKaG9v7zOb99PW1saePXs4ePAgzc3N9OvXD29vb0aPHk1rays3btzgzJkzomqoqalhw4YNfPLJJxQUFKDVanF0dGTYsGFGjadPn+a9994TVcfdVFdX8/PPPxMXF0dYWBjbt2/n888/Z8KECSa1c/LkSX744Qdqamq61b65uZmysjKTarifxsZGjhw5wpUrV/D392fz5s2cO3eOgoICCgoKSEpK4t1338XPz4+MjAzmzp2LWq0WVdOzTGFhIWq1ms7OTjo7O9Hr9X1i9+DBg/zzn//k9OnTtLW1ER4ezvbt2/nkk0+YOXMmmzZt4uTJk0/12U8qGesVZWVlJCYmkp+fz8CBAxk9erSY5gDQ6XQcO3aMrVu3otFoeP3114mOjhbd7v20tLTwr3/9i61bt1JaWoogCFhbWzN48GD69esH3HnI09PTRdV3+vRpEhMTKS8vx9LSksDAQP70pz8xbtw4Dh48yGeffUZHRwe3b98WTcPdNDc3s3v3bnbu3Im7uztLly5l4sSJCIKAQqEwmZ0RI0Ywc+ZMsrKyaGpqwsHB4bHtBUGgpqaGlpYWk2l4GJ2dnWg0GoYOHcqKFSuIiorCysoKufzO/MfR0RFvb29mz57NP/7xDzIzM0lPT2fatGmiacrIyODMmTNkZ2eTl5dHfX09zs7OTJs2jbCwMCZNmoSlpaVo9h9HbW0tZWVlKJVKnJ2d+8SHfPHFF/z000+0trby4osv0tDQwNChQwkMDEQmkyEIAtnZ2aSmpjJ9+vQef76oM93c3FwOHTqERqPB2dmZJUuWiGkOgBMnTrBx40ZOnDjB6dOnuXTp0mPbd3Z2UllZyYEDB0yq49q1a6SmpnL9+nX0ej2+vr688cYbzJs3D1dXV3x8fGhqakKtVouWPPrqq6/46KOPuHTpEgaDgeeee45ly5bx0ksvMWLECNzc3IA716C1tVX0WR7cecDj4+PR6/XMnj2b6dOno1AoUCqVJrWjVCpxcnKio6ODzs7OJ7YfPHgw3t7elJWViep4zczMCA8PZ9asWQwfPhxra2ujwwVQKBQMGDCAiRMn8vbbb6PT6di7d2/vYoiPIT4+ns8//5wvv/yS/fv3c/78eQoLCzl9+jTfffcd77//PrGxsaLYfhJZWVn89ttv6HQ6FAoFgwcPxsvLSzR7tbW1fPPNN2zfvh1bW1s+/PBD3nzzTSZNmsTo0aNRKpUoFAq8vLx44YUX2LdvH0lJST22I9pMt6ioiKSkJIqLi3FycmLWrFnGh1wskpKS+O6770hPTwcgNDSUKVOmPLK9wWAgPz+fv/71r1RXVxuX/r1Fr9eTmJhITk4OHR0duLu7s2TJEhYvXsyAAQOQyWQsWrSI//mf/6GkpIQ9e/bwwQcf9NpuF7du3SI5OZk9e/aQl5eHVqtl7NixrFy5kvnz56NSqdDr9YwcORJfX19ycnK4fv06GzZsYNu2bSbTcT/l5eUcPXqUa9euMX/+fJYvX46NjY1o9szMzLrd1tbWlqFDh1JXV8eZM2eYMWOGKJqcnJx47bXX6OjoeOzsW6lUMnjwYOBOMkkw8RGsnZ2dHDp0iG+//ZZz586h0+lwdHQkNDQUd3d3KioqyMjI4OzZswwcOJAFCxagUqlMquFx6PV6KioqKC0tBcDFxYV58+ZhYWEhmr0jR46wc+dOvL29eeWVV5g6dSpKpRIPD497+pJKpSIoKIi9e/eyf/9+Zs6c2SNbojjd1tZWjh8/zq+//grAyJEjmTVrVo8egp6SmZnJli1bSEtLQxAEZs+ezWuvvYa/v/9D2wuCQFVVFZs2bSI1NRWZTEZMTAyfffZZr7WkpaWRkJBAVVUV/fr1Y9asWURHR2Nvbw/AoEGDCAoKwtHRkZaWFqqqqnpt825OnDjB5s2buXTpEkOGDCEqKopZs2bh6+vLoEGDgDvxdjs7O6Om5uZm1Go1arVatKXs8ePHOXHiBLa2tvj7++Pu7i6KnS560t/kcjl2dnaoVCrOnz8vmtO1sLBg6NChT2yn1WqN/UKMgampqYmdO3eSmZmJq6srUVFRTJ06FVdXV2xsbEhPT6eiooLq6mqKi4u5cOGCqCGO+7l48SJxcXEUFhYac0MjRowQzV5sbCxbtmzBy8uLV199lcDAQON1d3FxuaetmZkZ9vb2uLq6cvbsWSoqKh5o8zhECS+cOnWK2NhYKioq8PLy4uWXX2b8+PFimAIgOzubr776CrVajUajISQkhOXLlxMcHIytre1D31NXV8euXbs4ePAgMpmMqKgoXnzxRZPo+f3337l69SodHR2MHj2a8PBwPD09ja+bm5vTv39/UUbtrnj2+fPnkcvljB8/nsjISKZNm8bgwYORyf5vZ6JcLjc6JkEQRE1SlJeXk5KSQnV1NcHBwUyZMuWeZbUYjB07tkexSDs7O4YMGUJmZqaIqp6MwWAgNzeXL774ArlcTmhoqMknLBkZGeTn5+Pi4sKaNWt4/fXXCQ8PZ8yYMTg4ONDe3m5MQLa2tlJSUmJS+0+ipKSE7OxsmpqaUKlUjB07VrTQwt69e9m4cSO3b99m/vz5BAQEPHGgU6lUhISEcPPmTS5cuNAjeybv9VVVVSQlJZGVlYWlpSWTJ09m/vz5j3R+vUWn0/HLL79w8uRJWlpaCAsLY8WKFQQHBz/ywnVl7H/88UcaGhoICQlh/fr1BAYG9lrP9evXyczMpKamBisrK8LDw5k6dapJk0SPY+fOnWRlZdHW1kZkZCQrV65k7NixWFhY3ONwAezt7Xn++eeNv+t0OlESap2dnRw7dozMzEwEQaC9vZ3z58+Tm5uLVqs1ub0unJycuHXrFjU1Nd0aUJydnQkODu52tYNY3Lp1i6NHj5KSkoJCoWDMmDEP3LveUlZWhkajITo6mgULFhhXHQUFBWzZsoXvv//eONM2MzMTbVn/MG7evElWVhYlJSUIgoCDgwPjxo0TJbxRUlLCtm3baG9vZ9myZQQFBXXLV9nY2BAYGMjw4cNJS0vrVt6gC5OGF9rb2zl+/DhqtRqdTkdISAhLlizBycnJlGaMCILAzz//zOHDh6mvrycsLIxVq1Yxbdq0x96gmpoaYmNjKS0tZcyYMfz9739n3LhxJtF05MgRLl++bIyj+vv7G2NzYpOWlsbZs2dpaWnB3d2dyMhIQkJCsLa2fmj7rrrBu38Xwwk2Nzfz+++/c/36dQCSk5O5fPkynp6eBAUFMW3aNHx9fU1uV6lUUltby9mzZxkxYgQDBw58bHs7Ozuee+45mpqaqKys7FYYwNSUlpby888/c+DAAfr378+yZcsIDg4WzV5nZyeXLl2isLDQmEBLT0833isLCwtcXV1FXaneT3l5OQUFBTQ2NuLo6EhYWBghISEmt9OVOMvPz+e1115jyZIl3b7nZmZmDBkyhMDAQNRqNU1NTdjZ2XXrvSZzuoIgkJ6eTmxsLAUFBXh5ebFw4UKmTp1qKhMPkJqayubNmykuLsbb25tXX32VsLAw+vfv/8j3NDU1kZqaSnJyMkOGDGHdunWEhYWZRI9er+fcuXPU1taiUCiYPn06Y8eOfaCdTqejrq6OtrY2kyVImpub2bJlC1VVVbi5ubF48WL8/PyM5Wn301Ui1VXdIZPJsLa2ZtKkSSbRczdZWVnk5+ej0WhQKpW0tbVRXl7O1atXyczMNO7YM3U50KBBgxg+fDgXLlxg7ty5T3S6MpkMMzMzYzItKirKpHoeh06no6qqil9//ZXt27dTU1PDnDlzWL16tbGm2pSMHj0aGxsb4uPjyczMRCaTUV5eTllZGdbW1ri7u1NfX4+ZmRkjR47kueeeM7mGh1FVVYVaraagoACZTMaoUaNYtGiRKKGFvLw8Dh48yMyZM/nDH/6Ah4cH5ubdd4mWlpaMGjWKHTt20NbW1vdONz8/n127dpGWloalpSUzZswgPDwcKysrU5l4gAMHDhgL/t3c3GhtbSUxMRFzc3M6OztpamqiubmZjo4O4M6SQC6Xs3v3bvR6PcuWLWPhwoUm01NVVUVFRQVarZahQ4cyZcqUByo2BEGgsrIStVpNQ0MDNjY2eHh49Nr2+fPnUavVNDc3ExAQwIIFC3BxcXnkslSr1VJWVsbly5eB/9vIMnLkyF5ruRtBEEhKSqKkpAQ7Ozv8/f0JCAjA2tqaGzdukJWVRVJSkskqR+5GpVIRFhZGQkLCE7fUGgwGYwhCqVRy7tw5wsPDsbCwMHk52/3cvn2brKwsUlJSOHPmDNbW1sycOZNFixbh6elp8tACgI+PD97e3iQkJJCXlwfcmdV6enoaJyFxcXHY2Ngwfvz4PguPZWRk8Ouvv1JcXAyAu7s7fn5+olyD9PR0GhoaiIqKYtiwYT1yuHBntjtgwIAe2zWJ0+3s7CQ+Pp7k5GT0ej3h4eFER0ebxJk8josXLxp3nF25coUtW7bQ0tKCUqnEYDBQW1tLfX09Go0GuLN8HDx4MOXl5fj6+rJq1SqTFX0bDAaSkpIoKiqio6ODiRMnMnz48AdupF6vJz8/nwMHDiCTyRgzZoxJMuW5ubm0trYC4O3tjaur62OdRV1dHRcvXqS8vByAfv36PbLSozdUVVWRk5NDfX09M2fOZM2aNYSEhNCvXz/q6+vZt28fX3/9tbE0yNSMHz+eHTt2UFNTg8FgQKvVUl1dTWtrKx0dHTQ2NtLc3ExraystLS1UVFSg0+k4ffo0e/bsISAggDFjxoiirampiaKiIlJSUkhISKC0tJRRo0bx5z//mdDQUNzc3ERLNvbr14+lS5dib2/P9evXMTMzw9XVlUmTJjFixAhOnDiBubk5fn5+hIaGiqLhfhoaGkhJSeHy5cvo9XqcnZ3x8fF57Mr1aWlrayMpKQlLS0vs7e2fOlFpMBgYMGBAjwZmkzjdgoIC1Go1lZWVTJgwgT/+8Y8mSUo9CaVSiVwuRyaTce3aNczMzHBycmLQoEG4uLjg7e1tbKvRaLh+/Tr5+fmoVCq8vLxMOii0t7ezf/9+qqqqsLCwwN/f/4FZrlarpbCwkMTERK5evcqQIUNYsGCBSeLJFRUVGAwG7OzsCAkJeewIrNVqycvLIzU1Fbgzy3Vzc2PZsmW91nE/169fp7q6GhsbGyIiIggJCTE+RF0dValUipZodXR0BO4M0AMGDKCwsJALFy7Q0dGBjY0NBoOBtrY22tvb0ev16PV6lEolRUVFFBUV4ePjI4qu27dvc+LECeLi4sjJyUGlUhEVFUVkZCSTJk0StbwS7sxqo6OjmTp1KqWlpSgUCtzd3ZHL5Rw+fJj4+HhUKhVz5841+ernYWg0GhITE0lNTaW+vh6VSkV4eDhz5swRxV59fT0XL15k+PDhDBw48Kmud1tbG0VFRYSGhvao//ba6dbV1RETE8OFCxdQKBSEhob2SacBWLNmDWZmZlRWVmJlZcXAgQOZPHky48aNw9fX954Tq4qLi9m0aRM3btwgNDSU8PBwk2rp6Ogwhhbc3d3x8vK6J5nX1tZGTk4Ou3btYt++ffTv35/AwEDCwsJMMptRq9VotVocHBxwcnJ67MhbVlbG0aNHOX/+PDKZDHt7eyIiIpg4cWKvdTyKUaNG3TNr0el0ZGZmsnfvXgDRtncOHToUX19fTpw4QVFREbdv36azsxM3Nzc8PT0ZO3YsDg4OKJVKBgwYQHNzMz/++CPHjx/n448/Nnk/bmpqoqGhgYSEBLZt20ZJSQnjx49n6dKlzJ492zhI9BWDBg0y1m7rdDpSU1P597//TV5enuh94m4aGxvZvXs3ly5dMsZyIyIiRHX4MpkMlUr10MqeJ6HRaMjJyeH48eO8++67Paru6LXTTUtLIykpiZqaGiZMmEBwcHCfHc8XFRWFm5sbZWVlDB48GA8PDxwcHB6IP+n1em7cuEFGRgbDhg1jw4YNomVjBUEwLu8VCgWCINDQ0EBGRgY7duwgISEBg8HAtGnT+PDDDx+aaHsa2tra6OzspKSkhNzcXMaMGfPQwH5tba1xhtXU1GQMK6xdu9YkOh6Fh4eHcSOGTqfjypUr7N+/n/T0dF544QX8/PxEsWtvb090dDSffvop5ubmrF69Gn9/f1Qq1UNjeAqFgoCAAPbv309lZSWurq4m09LS0kJKSgqnTp3iyJEjdHR0EBERwcKFCwkKChJlGd0TampqOHLkCKdOncLLy4uIiAhRt93ejcFgQKPRoNPpsLCw4Pnnn2fixImixdOVSiUuLi7cuHGDa9eu4ejo2CPH2dDQQHx8PLdv3+7xpo1eOV2dTodaraaiogKFQmEcGcUuer8bPz+/Jz6wjY2NZGdnc+3aNcLDw0XJBgPGSgQXFxdsbW3p6Ojgxo0bJCYmsm3bNs6ePWtc/r/11lsmc7gAM2bMoKSkhJaWFg4dOsSECRN4/vnnjZ1WEARaW1s5efIkMTExlJSUoFQq8fLyYv78+aKV9XWdfdC15fXWrVuUlpYSExPDnj178PT0ZN68efdsHjE1vr6+zJo1i6CgIKZOnfrYwne5XI5cLqeyspKUlBRefvnlp7bbVSFSXV1tLM2KiYkhPT0dc3NzoqKiePvtt/u0HOtRaLVaLl68SHZ2NmZmZkybNo3IyMg+WbF2TYq6chIuLi6MHj36iYcU9YaBAweydu1aPvjgA77//nucnJwYOXLkE//frueosLCQ8vJynJ2dexwa65XTLS4uJjk5mdu3bxMQEEBISEif1aT2hJaWFmpraxk6dCgzZswQbUYhk8mQyWSUlpZSUFBARUUFu3btIj4+nsbGRpydnZk7dy5r1641+bIpKCiIQ4cO0d7eTl5eHjExMQiCgIeHB3K5nIaGBgoLC4mNjeXMmTPIZDI8PDxYvny5aHEzuLOcNhgMqNVq5HI5VlZWpKenc+HCBaytrYmMjGTRokWi2QcYPnw4H3/8cY/eo9fre7VJoq6ujurqanbv3s2OHTtob29HLpej1WrRarVYWFhga2tLZ2cnbW1tWFpa9ulk5X4KCwvZunUrqampBAYGMmfOHGPYQWwKCgr49NNPuXDhgvEkvO5uUnhazMzMeOWVV1Cr1Rw6dAgvLy+WLFmCh4fHA/dCEAR0Oh16vd6Y7Pvuu+9obGxk/fr1TyxFvJ9eOV21Wk19fT1yuZy//OUvomS/TYGZmRn9+/cnODiYlStXimJDJpNhaWmJmZkZJ0+eNJ6Tq9FokMvl+Pj4EBkZyYIFC0SJUy1cuJC8vDy2bt3KzZs32bJlC7m5ucZSmLy8PC5dumScTXQd2vHSSy+JugmgK/sfHx/P119/jUwmMxbcr1y5knXr1vW4VEdsurZHd5UaPg0//PADX3zxBc3NzVhYWODk5ERERATm5uakpKRQVlZGbm4uu3btYuLEifj5+TFo0CBsbGxEL1G7H41Gw5kzZ7h69SoKhYKQkBBeeOGFPrEtCALffvstycnJtLa2EhQUZDx6tC948803aWho4KeffiInJ4fFixczadKke1ZDXWdfV1ZWcvbsWQ4cOICXlxerV69+qpVQr3p7cnIydXV1ODo6Ym9v32e1fD3F0dGRCRMmdPuYv6fBxsaGOXPmcOvWLSoqKmhtbUWhUODg4MDkyZNZvHgx4eHhxrimGHz44YfI5XJiY2O5ceMGOTk5nD179h7nIZPJ6NevHyEhIbzxxhui77pydHRkw4YNWFpakpycjJWVFZMnT2bp0qWEhYU9kw7X1tYWR0dHzp0799Sfo1AosLKyYsqUKQQEBLB8+XJjCEer1ZKRkUFKSgrHjh0jJiYGW1tbQkJCWLZsmdHhmJubG0+lE4vOzk4SEhL48ccfKSgoYM6cOURERPTZrLuqqorKykrjTkgbG5sHjrsUEz8/P/72t78RGxvL4cOHyczMZNCgQcaBTy6X09HRQW1tLc3NzQwZMoSXXnqJNWvWPPWBTbIn7Ih67IsfffQRcXFxzJ8/nxUrVpjq6EZRvrO+K2vs7Ozc3ThVj3VoNBr27NnDN998Q1tbG4GBgSxatKi3SZKnuh5dBeYJCQlkZ2cb65ltbW1ZuHAh69ev72kcV5T78hSIrqO6upqNGzdy7tw5jh49+lQ6ur7poCtG/DC6Yr5Hjhxh+/bt5OTkIJPJjAORj48PcXFxT+o7j/LI3bomRUVFvP/++xw5cgSVSsW7777bm/r1Ht+bd955h8OHD9PY2Gj8/fXXX+9tMr7HOgwGA8eOHSM7O5vffvuNK1euoNPpsLe3Z8KECfj4+DBy5EhmzJiBnZ1ddycLD703vXK6IvFf83B3E0nHvYiuQ6vVkpmZyZdffvm4w+2f5esB3dSyZs0afvnlFzQaDatWrWLt2rW9qV9/lq/Js6JD3K/rkZD4/4iFhQVTp04V9dyQZ4WuTSGhoaF9sotUQuSv65GQkHi2kclkjBs3jsWLF4u23VniXqTwwqORdNyLpONenmUd8OxokXTc/0dTf/eShISEhMSjkcILEhISEn2I5HQlJCQk+hDJ6UpISEj0IZLTlZCQkOhDJKcrISEh0YdITldCQkKiD/lftwu/IjB4gO8AAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# We need matplotlib library to plot the dataset\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Plot the first 10 entries of the labels and the dataset\n", + "figure = plt.figure()\n", + "num_of_entries = 10\n", + "for index in range(1, num_of_entries + 1):\n", + " plt.subplot(6, 10, index)\n", + " plt.axis('off')\n", + " plt.imshow(dataloader.dataloader1.dataset.data[index].numpy().squeeze(), cmap='gray_r')\n", + " print(dataloader.dataloader2.dataset[index][0], end=\" \")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------------\n", + "Id 0 of the DataLoader 1: 0000c647-1520-46ef-ba38-c0108845a2e1\n", + "Id 0 of the DataLoader 2: 0000c647-1520-46ef-ba38-c0108845a2e1\n", + "----------------------------------------\n", + "Id 1 of the DataLoader 1: 000259fd-f80a-4167-b451-8ec0e9ac42bc\n", + "Id 1 of the DataLoader 2: 000259fd-f80a-4167-b451-8ec0e9ac42bc\n", + "----------------------------------------\n", + "Id 2 of the DataLoader 1: 00027ad7-ece2-4d05-861a-e7fb7cdabc58\n", + "Id 2 of the DataLoader 2: 00027ad7-ece2-4d05-861a-e7fb7cdabc58\n", + "----------------------------------------\n", + "Id 3 of the DataLoader 1: 00028968-fd8b-4634-af19-c6cd05c42710\n", + "Id 3 of the DataLoader 2: 00028968-fd8b-4634-af19-c6cd05c42710\n", + "----------------------------------------\n", + "Id 4 of the DataLoader 1: 0004c0df-5252-45e2-9c35-0437c086438b\n", + "Id 4 of the DataLoader 2: 0004c0df-5252-45e2-9c35-0437c086438b\n", + "----------------------------------------\n", + "Id 5 of the DataLoader 1: 00058d30-ab69-44d9-a95b-d8ea22bbf9ba\n", + "Id 5 of the DataLoader 2: 00058d30-ab69-44d9-a95b-d8ea22bbf9ba\n", + "----------------------------------------\n", + "Id 6 of the DataLoader 1: 00059558-4e98-4835-8af3-2b2a0d14e438\n", + "Id 6 of the DataLoader 2: 00059558-4e98-4835-8af3-2b2a0d14e438\n", + "----------------------------------------\n", + "Id 7 of the DataLoader 1: 00068c68-e706-4ca6-8dfb-0acb723a1cb4\n", + "Id 7 of the DataLoader 2: 00068c68-e706-4ca6-8dfb-0acb723a1cb4\n", + "----------------------------------------\n", + "Id 8 of the DataLoader 1: 00079c07-609f-411d-8bc2-a859783fdda4\n", + "Id 8 of the DataLoader 2: 00079c07-609f-411d-8bc2-a859783fdda4\n", + "----------------------------------------\n", + "Id 9 of the DataLoader 1: 0007a5d4-30ff-475b-b2a8-a82193f4002e\n", + "Id 9 of the DataLoader 2: 0007a5d4-30ff-475b-b2a8-a82193f4002e\n", + "----------------------------------------\n" + ] + } + ], + "source": [ + "print(\"----------------------------------------\")\n", + "for i in range(0,10):\n", + " print(\"Id\", i, \"of the DataLoader 1: \", dataloader.dataloader1.dataset.ids[i])\n", + " print(\"Id\", i, \"of the DataLoader 2: \", dataloader.dataloader2.dataset.ids[i])\n", + " print(\"----------------------------------------\")" ] }, {