Skip to content

Commit 055af09

Browse files
Check that ET and Eager are numerically equivalent
Differential Revision: D61494957 Pull Request resolved: #4779
1 parent f887d72 commit 055af09

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

exir/tests/TARGETS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ python_unittest(
109109
deps = [
110110
"//caffe2:torch",
111111
"//executorch/exir:lib",
112+
"//executorch/extension/pybindings:portable_lib",
112113
],
113114
)
114115

@@ -209,6 +210,7 @@ python_unittest(
209210
"//executorch/exir/passes:debug_handle_generator_pass",
210211
"//executorch/exir/passes:insert_write_back_for_buffers_pass",
211212
"//executorch/exir/passes:lib",
213+
"//executorch/exir/passes:memory_format_ops_pass",
212214
"//executorch/exir/passes:normalize_view_copy_base_pass",
213215
"//executorch/exir/passes:remove_graph_asserts_pass",
214216
"//executorch/exir/passes:remove_mixed_type_operators",

exir/tests/test_joint_graph.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
import torch._dynamo
1212

1313
from executorch.exir import to_edge
14+
15+
from executorch.extension.pybindings.portable_lib import (
16+
_load_for_executorch_from_buffer,
17+
)
1418
from torch.export._trace import _export
1519
from torch.export.experimental import _export_forward_backward
1620
from torch.export.exported_program import OutputKind
@@ -89,3 +93,18 @@ def forward(self, x, y):
8993
.val.allocation_info.memory_offset_low,
9094
48,
9195
)
96+
97+
loss = m(*example_inputs)
98+
loss.backward()
99+
et_mod = _load_for_executorch_from_buffer(et.buffer)
100+
et_outputs = et_mod.forward(
101+
example_inputs
102+
) # ET outputs are [loss, grads, weights]
103+
104+
self.assertTrue(torch.allclose(loss, et_outputs[0]))
105+
self.assertTrue(
106+
torch.allclose(m.linear.weight.grad, et_outputs[1]) # pyre-ignore[6]
107+
)
108+
self.assertTrue(torch.allclose(m.linear.bias.grad, et_outputs[2]))
109+
self.assertTrue(torch.allclose(m.linear.weight, et_outputs[3]))
110+
self.assertTrue(torch.allclose(m.linear.bias, et_outputs[4]))

0 commit comments

Comments
 (0)