-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Part 12 - Train an Encrypted Neural Network on Encrypted Data
- Loading branch information
harsh_kasyap
authored and
harsh_kasyap
committed
May 14, 2020
1 parent
7c995aa
commit 37cb5ec
Showing
1 changed file
with
171 additions
and
0 deletions.
There are no files selected for viewing
171 changes: 171 additions & 0 deletions
171
PySyft/examples/Part 12 - Train an Encrypted Neural Network on Encrypted Data.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was '/opt/conda/lib/python3.7/site-packages/tf_encrypted/operations/secure_random/secure_random_module_tf_1.15.2.so'\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"WARNING:tensorflow:From /opt/conda/lib/python3.7/site-packages/tf_encrypted/session.py:24: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.\n", | ||
"\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"import torch.nn.functional as F\n", | ||
"import torch.optim as optim\n", | ||
"import syft as sy" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
# Build the hooked runtime and the three virtual participants:
# two share holders (alice, bob) plus a crypto provider (james).
hook = sy.TorchHook(torch)

alice = sy.VirtualWorker(hook=hook, id="alice")
bob = sy.VirtualWorker(hook=hook, id="bob")
james = sy.VirtualWorker(hook=hook, id="james")
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
# A toy binary dataset: four 2-D points and their scalar labels.
data = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
target = torch.tensor([[0.], [0.], [1.], [1.]])


class Net(nn.Module):
    """A minimal two-layer perceptron: 2 -> 2 -> 1 with a ReLU between."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 1)

    def forward(self, x):
        # Hidden layer with ReLU activation, then the linear output layer.
        return self.fc2(F.relu(self.fc1(x)))


model = Net()
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/opt/conda/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/additive_shared.py:79: UserWarning: Use dtype instead of field\n", | ||
" warnings.warn(\"Use dtype instead of field\")\n", | ||
"/opt/conda/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/additive_shared.py:91: UserWarning: Default args selected\n", | ||
" warnings.warn(\"Default args selected\")\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
# Encode everything: convert to fixed precision, then secret-share each
# tensor between bob and alice (james supplies the crypto primitives).
share_kwargs = dict(crypto_provider=james, requires_grad=True)

data = data.fix_precision().share(bob, alice, **share_kwargs)
target = target.fix_precision().share(bob, alice, **share_kwargs)
model = model.fix_precision().share(bob, alice, **share_kwargs)
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"tensor(0.8350)\n", | ||
"tensor(0.7070)\n", | ||
"tensor(0.6110)\n", | ||
"tensor(0.5240)\n", | ||
"tensor(0.4390)\n", | ||
"tensor(0.3560)\n", | ||
"tensor(0.2800)\n", | ||
"tensor(0.2140)\n", | ||
"tensor(0.1660)\n", | ||
"tensor(0.1350)\n", | ||
"tensor(0.1010)\n", | ||
"tensor(0.0740)\n", | ||
"tensor(0.0610)\n", | ||
"tensor(0.0380)\n", | ||
"tensor(0.0330)\n", | ||
"tensor(0.0210)\n", | ||
"tensor(0.0140)\n", | ||
"tensor(0.0110)\n", | ||
"tensor(0.0080)\n", | ||
"tensor(0.0070)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
# Train the secret-shared model with a fixed-precision SGD optimizer.
opt = optim.SGD(params=model.parameters(), lr=0.1).fix_precision()

# NOTE: renamed the loop variable from `iter` to `step` — `iter` shadows
# the Python builtin of the same name.
for step in range(20):
    # 1) erase previous gradients (if they exist)
    opt.zero_grad()

    # 2) make a prediction
    pred = model(data)

    # 3) calculate how much we missed (sum of squared errors)
    loss = ((pred - target) ** 2).sum()

    # 4) figure out which weights caused us to miss
    loss.backward()

    # 5) change those weights
    opt.step()

    # 6) decrypt and decode the loss so progress is readable in the clear
    print(loss.get().float_precision())
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |