Skip to content

Commit

Permalink
Part 12 - Train an Encrypted Neural Network on Encrypted Data
Browse files Browse the repository at this point in the history
  • Loading branch information
harsh_kasyap authored and harsh_kasyap committed May 14, 2020
1 parent 7c995aa commit 37cb5ec
Showing 1 changed file with 171 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was '/opt/conda/lib/python3.7/site-packages/tf_encrypted/operations/secure_random/secure_random_module_tf_1.15.2.so'\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /opt/conda/lib/python3.7/site-packages/tf_encrypted/session.py:24: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.\n",
"\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"import syft as sy"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Set everything up\n",
"hook = sy.TorchHook(torch) \n",
"\n",
"alice = sy.VirtualWorker(id=\"alice\", hook=hook)\n",
"bob = sy.VirtualWorker(id=\"bob\", hook=hook)\n",
"james = sy.VirtualWorker(id=\"james\", hook=hook)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# A Toy Dataset\n",
"data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]])\n",
"target = torch.tensor([[0],[0],[1],[1.]])\n",
"\n",
"# A Toy Model\n",
"class Net(nn.Module):\n",
" def __init__(self):\n",
" super(Net, self).__init__()\n",
" self.fc1 = nn.Linear(2, 2)\n",
" self.fc2 = nn.Linear(2, 1)\n",
"\n",
" def forward(self, x):\n",
" x = self.fc1(x)\n",
" x = F.relu(x)\n",
" x = self.fc2(x)\n",
" return x\n",
"model = Net()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/additive_shared.py:79: UserWarning: Use dtype instead of field\n",
" warnings.warn(\"Use dtype instead of field\")\n",
"/opt/conda/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/additive_shared.py:91: UserWarning: Default args selected\n",
" warnings.warn(\"Default args selected\")\n"
]
}
],
"source": [
"# We encode everything\n",
"data = data.fix_precision().share(bob, alice, crypto_provider=james, requires_grad=True)\n",
"target = target.fix_precision().share(bob, alice, crypto_provider=james, requires_grad=True)\n",
"model = model.fix_precision().share(bob, alice, crypto_provider=james, requires_grad=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(0.8350)\n",
"tensor(0.7070)\n",
"tensor(0.6110)\n",
"tensor(0.5240)\n",
"tensor(0.4390)\n",
"tensor(0.3560)\n",
"tensor(0.2800)\n",
"tensor(0.2140)\n",
"tensor(0.1660)\n",
"tensor(0.1350)\n",
"tensor(0.1010)\n",
"tensor(0.0740)\n",
"tensor(0.0610)\n",
"tensor(0.0380)\n",
"tensor(0.0330)\n",
"tensor(0.0210)\n",
"tensor(0.0140)\n",
"tensor(0.0110)\n",
"tensor(0.0080)\n",
"tensor(0.0070)\n"
]
}
],
"source": [
"opt = optim.SGD(params=model.parameters(),lr=0.1).fix_precision()\n",
"\n",
"for iter in range(20):\n",
" # 1) erase previous gradients (if they exist)\n",
" opt.zero_grad()\n",
"\n",
" # 2) make a prediction\n",
" pred = model(data)\n",
"\n",
" # 3) calculate how much we missed\n",
" loss = ((pred - target)**2).sum()\n",
"\n",
" # 4) figure out which weights caused us to miss\n",
" loss.backward()\n",
"\n",
" # 5) change those weights\n",
" opt.step()\n",
"\n",
" # 6) print our progress\n",
" print(loss.get().float_precision())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

0 comments on commit 37cb5ec

Please sign in to comment.