Skip to content

Commit

Permalink
Enable CrypTen export to PyTorch
Browse files Browse the repository at this point in the history
Summary:
Enables CrypTen to export to PyTorch models using `to_pytorch()` when initially imported to CrypTen using `from_pytorch()`.

This is a simple implementation for now. A more advanced implementation could export without initial `from_pytorch()` import.

Also added some additional testing to `test_onnx_converter.py` to compare model exports against the encrypted model output.

Reviewed By: karthikprasad

Differential Revision: D31938254

fbshipit-source-id: fb8f9c627d3ee77ccff38feed50c3fe9ec58e9a8
  • Loading branch information
knottb authored and facebook-github-bot committed Oct 29, 2021
1 parent 8cc958b commit 17f9d54
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 0 deletions.
12 changes: 12 additions & 0 deletions crypten/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,18 @@ def _clear_unused_values():
# this should never happen:
raise ValueError("nn.Graph.forward() failed. Is graph unconnected?")

def to_pytorch(self):
if not hasattr(self, "pytorch_model"):
raise AttributeError("CrypTen Graph detached from PyTorch model.")
if self.encrypted:
raise ValueError(
"CrypTen model must be decrypted before calling to_pytorch()"
)
with torch.no_grad():
for name, param in self.pytorch_model.named_parameters():
param.set_(self._modules[name].data)
return self.pytorch_model


class Sequential(Graph):
"""
Expand Down
4 changes: 4 additions & 0 deletions crypten/nn/onnx_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# LICENSE file in the root directory of this source tree.


import copy
import io

import onnx
Expand Down Expand Up @@ -46,6 +47,9 @@ def from_pytorch(pytorch_model, dummy_input):
crypten_model = from_onnx(f)
f.close()

# set model architecture to export model back to pytorch model
crypten_model.pytorch_model = copy.deepcopy(pytorch_model)

# make sure training / eval setting is copied:
crypten_model.train(mode=pytorch_model.training)
return crypten_model
Expand Down
14 changes: 14 additions & 0 deletions test/test_onnx_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ def forward(self, x):

self._check_training(model, x_train, y_train, loss_name)

self._check_model_export(model, x_train)

def test_from_pytorch_training_regression(self):
"""Tests from_pytorch CrypTen training for regression models"""
import torch.nn as nn
Expand Down Expand Up @@ -251,6 +253,7 @@ def forward(self, x):
model.encrypt()

self._check_training(model, x_train, y_train, "MSELoss")
self._check_model_export(model, x_train)

def _check_training(
self, model, x_train, y_train, loss_name, num_epochs=2, learning_rate=0.001
Expand Down Expand Up @@ -308,6 +311,17 @@ def _check_training(
f"{loss_name} has not decreased after training",
)

def _check_model_export(self, crypten_model, x_enc):
"""Checks that exported model returns the same results as crypten model"""
pytorch_model = crypten_model.decrypt().to_pytorch()
x_plain = x_enc.get_plain_text()

y_plain = pytorch_model(x_plain)
crypten_model.encrypt()
y_enc = crypten_model(x_enc)

self._check(y_enc, y_plain, msg="Model export failed.")

def test_get_operator_class(self):
"""Checks operator is a valid crypten module"""
Node = collections.namedtuple("Node", "op_type")
Expand Down

0 comments on commit 17f9d54

Please sign in to comment.