Skip to content

Commit dec63f7

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
Update verifier to handle None Tensor outputs (#8235)
Summary: conv returns (None, Tensor, Tensor) which is uncommon to see since the schema is (Tensor, Tensor, Tensor). This is to test that the verifier just ignores the None return value (since itll be unused in the runtime). Reviewed By: larryliu0820 Differential Revision: D69209059
1 parent ab805ff commit dec63f7

File tree

2 files changed

+55
-6
lines changed

2 files changed

+55
-6
lines changed

exir/verification/arg_validator.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,23 @@ def call_function( # noqa: C901 # pyre-fixme[14]
108108
for schema_ret in target._schema.returns:
109109
name = schema_ret.name if schema_ret.name else f"__ret_{ret_index}"
110110
kernel_ret = next(ret_iter)
111-
# Return value should not be in OptionalTensor type, so only check torch.TensorType here.
112-
if isinstance(schema_ret.type, torch.TensorType) and isinstance(
113-
kernel_ret, torch.Tensor
114-
):
115-
tensor_arg_types[name] = kernel_ret.dtype
116-
ret_index += 1
111+
if isinstance(schema_ret.type, torch.TensorType):
112+
if isinstance(kernel_ret, torch.Tensor):
113+
tensor_arg_types[name] = kernel_ret.dtype
114+
ret_index += 1
115+
# Exceptionally rarely (basically only backwards ops) you might see an OptionalTensor returned.
116+
# The schema of these ops though is typically -> (Tensor, Tensor ...). So the actual type
117+
# returned in cpp is empty/undefined tensor. There is no analogy to this in python so it
118+
# gets crudely mapped to None. To properly fix this core pytorch would have to change the
119+
# schema to (Tensor?, ...) which is just never going to happen. So we have to handle this case
120+
# here in the verifier and in memory planning as well.
121+
elif kernel_ret is None:
122+
tensor_arg_types[name] = schema_ret.default_value
123+
ret_index += 1
124+
else:
125+
raise InternalError(
126+
f"encountered return with type Tensor but value wasnt a tensor or None. schema:{target._schema}, output:{ret_index}"
127+
)
117128
elif schema_ret.type == torch.ListType.ofTensors() and all(
118129
isinstance(kernel_ret[i], torch.Tensor) for i in range(len(kernel_ret))
119130
):

exir/verification/test/test_verifier.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414
from executorch.exir import EdgeCompileConfig, to_edge
1515

1616
from executorch.exir.dialects._ops import ops
17+
from torch import nn
1718
from torch._export.verifier import SpecViolationError
1819
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
1920
from torch.export import export
21+
from torch.export.experimental import _export_forward_backward
2022

2123
from ..verifier import EXIREdgeDialectVerifier
2224

@@ -123,3 +125,39 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
123125
dim_order_verifier(stride_edge_model.exported_program())
124126
with self.assertRaises(SpecViolationError):
125127
stride_verifier(dim_order_edge_model.exported_program())
128+
129+
def test_none_return_verifier(self) -> None:
130+
class Net(nn.Module):
131+
def __init__(self):
132+
super().__init__()
133+
self.conv1 = nn.Conv2d(6, 6, 5)
134+
self.linear = nn.Linear(6, 2)
135+
136+
def forward(self, x):
137+
return self.linear(self.conv1(x).flatten(1))
138+
139+
class TrainingNet(nn.Module):
140+
def __init__(self, net):
141+
super().__init__()
142+
self.net = net
143+
self.loss = nn.CrossEntropyLoss()
144+
145+
def forward(self, input, label):
146+
pred = self.net(input)
147+
return self.loss(pred, label)
148+
149+
# conv returns (None, Tensor, Tensor) which is uncommon to see since
150+
# the schema is (Tensor, Tensor, Tensor). This is to test that
151+
# the verifier just ignores the None return value (since itll be
152+
# unused in the runtime).
153+
net = TrainingNet(Net())
154+
inputs = (torch.randn(1, 6, 5, 5), torch.ones(1, dtype=torch.int64))
155+
156+
export_model = export(net, inputs)
157+
export_model = _export_forward_backward(export_model)
158+
159+
edge = to_edge(export_model)
160+
161+
edge_verifier = EXIREdgeDialectVerifier()
162+
163+
edge_verifier(edge.exported_program())

0 commit comments

Comments
 (0)