diff --git a/.circleci/config.yml b/.circleci/config.yml index 256ef8c8..13f387a7 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -34,7 +34,7 @@ jobs: no_output_timeout: 3h command: | . ~/crypten-test/bin/activate - echo 'for i in $(ls test/test_*.py | grep -Ev "test_(context|benchmark|tensorboard|models|cuda)"); do python3 -m unittest $i; (($? != 0)) && exit 1; done; exit 0' > run_tests.sh + echo 'for i in $(ls test/test_*.py | grep -Ev "test_(context|benchmark|models)"); do python3 -m unittest $i; (($? != 0)) && exit 1; done; exit 0' > run_tests.sh bash ./run_tests.sh - run: name: Linear svm example diff --git a/crypten/common/serial.py b/crypten/common/serial.py index 7ecbba8f..0818f22e 100644 --- a/crypten/common/serial.py +++ b/crypten/common/serial.py @@ -88,10 +88,10 @@ class RestrictedUnpickler(pickle.Unpickler): "torch.ByteStorage", "torch.DoubleStorage", "torch.FloatStorage", - # "torch._C.HalfStorageBase", - # "torch._C.QInt32StorageBase", - # "torch._C.QInt8StorageBase", - # "torch.storage._TypedStorage", + "torch._C.HalfStorageBase", + "torch._C.QInt32StorageBase", + "torch._C.QInt8StorageBase", + "torch.storage._TypedStorage", ] for item in __ALLOWLIST: diff --git a/setup.py b/setup.py index aed2e495..d58710dc 100644 --- a/setup.py +++ b/setup.py @@ -52,5 +52,5 @@ author=AUTHOR, license=LICENSE, tests_require=["pytest"], - data_files=[("/configs", ["configs/default.yaml"])], + data_files=[("configs", ["configs/default.yaml"])], ) diff --git a/test/test_debug.py b/test/test_debug.py index c4ab158f..6e9935f8 100644 --- a/test/test_debug.py +++ b/test/test_debug.py @@ -55,7 +55,6 @@ def test_correctness_validation(self): encrypted_tensor.add(10) # Ensure incorrect validation works properly for value - # tensor2 = get_random_test_tensor(size=(2, 2), is_float=True) encrypted_tensor.add = lambda y: crypten.cryptensor(tensor) with self.assertRaises(ValueError): encrypted_tensor.add(10) diff --git a/test/test_gradients.py b/test/test_gradients.py index 04c19c12..d4201b88 100644 --- a/test/test_gradients.py +++ b/test/test_gradients.py @@ -1222,7 +1222,6 @@ def tearDown(self): super(TestTFP, self).tearDown() -# @unittest.skip("Almost all TTP tests are timing out") class TestTTP(MultiProcessTestCase, TestGradients): def setUp(self): self._original_provider = cfg.mpc.provider