Skip to content

Commit

Permalink
[jit][edge] Pass through dynamic type for DictType. (pytorch#74025)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#74025

When users are trying to inspect IValues out of the Lite Interpreter, dynamic types are still attached, therefore torch::jit::toPyObject will fail on these dynamic types while converting dictionary keys.
We should just let dynamic types pass through under this corner case since they won't be used by anything later.
ghstack-source-id: 151051826

Test Plan:
buck test //caffe2/test:mobile -- -r 'test_bundled_input_with_dynamic_type'

without patch:
```
BUILD SUCCEEDED
Tpx test run coordinator for Facebook. See https://fburl.com/tpx for details.
Running with tpx session id: c6693277-2dad-4882-97c7-f69c58f67259
Trace available for this run at /tmp/tpx-20220310-000040.948069-c6693277-2dad-4882-97c7-f69c58f67259/trace.log
RemoteExecution session id: reSessionID-c6693277-2dad-4882-97c7-f69c58f67259-tpx
Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/6473924544183693
    ✓ ListingSuccess: caffe2/test:mobile : 40 tests discovered (2.122)
    ✗ Fail: caffe2/test:mobile - test_bundled_input_with_dynamic_type (mobile.test_lite_script_module.TestLiteScriptQuantizedModule) (3.059)
Test output:
> RuntimeError: Cannot create dict for key type 'Dynamic<8>', only int, float, complex, Tensor, device and string keys are supported
  File "/usr/local/fbcode/platform009/lib/python3.8/unittest/case.py", line 60, in testPartExecutor
    yield
  File "/usr/local/fbcode/platform009/lib/python3.8/unittest/case.py", line 676, in run
    self._callTestMethod(testMethod)
  File "/usr/local/fbcode/platform009/lib/python3.8/unittest/case.py", line 633, in _callTestMethod
    method()
  File "/data/users/zhxchen17/fbsource/fbcode/buck-out/dbg/gen/caffe2/test/mobile#binary,link-tree/mobile/test_lite_script_module.py", line 558, in test_bundled_input_with_dynamic_type
    i = mobile_module.run_method("get_all_bundled_inputs")
  File "/data/users/zhxchen17/fbsource/fbcode/buck-out/dbg/gen/caffe2/test/mobile#binary,link-tree/torch/jit/mobile/__init__.py", line 69, in run_method
    return self._c.run_method(method_name, input)
stdout:

stderr:

Summary
  Fail: 1
    ✗ caffe2/test:mobile - test_bundled_input_with_dynamic_type (mobile.test_lite_script_module.TestLiteScriptQuantizedModule)
  ListingSuccess: 1
If you need help understanding your runs, please follow the wiki: https://fburl.com/posting_in_tpx_users
Finished test run: https://www.internalfb.com/intern/testinfra/testrun/6473924544183693
```

Reviewed By: cccclai

Differential Revision: D34780805

fbshipit-source-id: 88b139c5e91becc031e4b06e055a78a52a429c09
(cherry picked from commit 41abbac)
  • Loading branch information
zhxchen17 authored and pytorchmergebot committed Mar 11, 2022
1 parent 794f813 commit 75ad6fe
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 2 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/core/dynamic_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class DynamicType : public SharedType {
const Arguments& arguments() const {
return arguments_;
}
TypeKind dynamicKind() const;
TORCH_API TypeKind dynamicKind() const;

// Should be used only on the server side to restore static type information.
#ifndef C10_MOBILE
Expand Down
6 changes: 5 additions & 1 deletion aten/src/ATen/core/jit_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,11 @@ struct TORCH_API DictType : public SharedType {
static const TypeKind Kind = TypeKind::DictType;

static DictTypePtr create(TypePtr key, TypePtr value) {
switch (key->kind()) {
auto kind = key->kind();
if (auto dyn = key->castRaw<DynamicType>()) {
kind = dyn->dynamicKind();
}
switch (kind) {
case TypeKind::AnyType:
case TypeKind::IntType:
case TypeKind::BoolType:
Expand Down
43 changes: 43 additions & 0 deletions test/mobile/test_lite_script_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,49 @@ def forward(self, x):
input = torch.randn(4, 1, 4, 4)
self._compare_script_and_mobile(model=model_int8, input=input)

def test_bundled_input_with_dynamic_type(self):
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(
self,
x: Dict[int, torch.Tensor],
y: Dict[int, torch.Tensor],
z: Dict[int, torch.Tensor],
):
return x

model = Model()
script_module = torch.jit.script(model)

sample_input = {
script_module.forward: [
(
{0: torch.ones(1)},
{1: torch.ones(1)},
{2: torch.ones(1)},
)
]
}

bundled_model = torch.utils.bundled_inputs.bundle_inputs(
script_module, sample_input
)

buf = bundled_model._save_to_buffer_for_lite_interpreter()
mobile_module = _load_for_lite_interpreter(io.BytesIO(buf))

i = mobile_module.run_method("get_all_bundled_inputs")

self.assertEqual(
i[0],
(
{0: torch.ones(1)},
{1: torch.ones(1)},
{2: torch.ones(1)},
),
)

if __name__ == '__main__':
run_tests()

0 comments on commit 75ad6fe

Please sign in to comment.