|
2 | 2 | import torch_tensorrt
|
3 | 3 | from torch.testing._internal.common_utils import TestCase, run_tests
|
4 | 4 |
|
5 |
| -from ..testing_utilities import lower_graph_testing |
| 5 | +from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing |
6 | 6 |
|
7 | 7 |
|
8 | 8 | class TestFakeTensors(TestCase):
|
@@ -57,6 +57,7 @@ def forward(self, x):
|
57 | 57 | self.assertAlmostEqual(
|
58 | 58 | max_diff,
|
59 | 59 | 0,
|
| 60 | + DECIMALS_OF_AGREEMENT, |
60 | 61 | msg=f"MulInt TRT outputs don't match with the original model.",
|
61 | 62 | )
|
62 | 63 | torch._dynamo.reset()
|
@@ -113,6 +114,7 @@ def forward(self, x):
|
113 | 114 | self.assertAlmostEqual(
|
114 | 115 | max_diff,
|
115 | 116 | 0,
|
| 117 | + DECIMALS_OF_AGREEMENT, |
116 | 118 | msg=f"AddFloat TRT outputs don't match with the original model.",
|
117 | 119 | )
|
118 | 120 |
|
@@ -157,5 +159,88 @@ def forward(self, x):
|
157 | 159 | torch._dynamo.reset()
|
158 | 160 |
|
159 | 161 |
|
| 162 | +class TestInputModifications(TestCase): |
| 163 | + def test_input_modifications_add(self): |
| 164 | + class InplaceAdd(torch.nn.Module): |
| 165 | + def forward(self, x): |
| 166 | + x += 3 |
| 167 | + y = x + 1 |
| 168 | + return y |
| 169 | + |
| 170 | + inputs = [ |
| 171 | + torch.rand( |
| 172 | + 3, |
| 173 | + 5, |
| 174 | + 7, |
| 175 | + ).cuda(), |
| 176 | + ] |
| 177 | + |
| 178 | + fx_graph = torch.fx.symbolic_trace(InplaceAdd()) |
| 179 | + |
| 180 | + # Validate that the results between Torch and Torch-TRT are similar |
| 181 | + optimized_model = torch_tensorrt.compile( |
| 182 | + fx_graph, |
| 183 | + "torch_compile", |
| 184 | + inputs, |
| 185 | + min_block_size=1, |
| 186 | + pass_through_build_failures=True, |
| 187 | + ) |
| 188 | + optimized_model_results = optimized_model(*inputs).detach().cpu() |
| 189 | + torch_model_results = fx_graph(*inputs).detach().cpu() |
| 190 | + |
| 191 | + max_diff = float( |
| 192 | + torch.max(torch.abs(optimized_model_results - torch_model_results)) |
| 193 | + ) |
| 194 | + self.assertAlmostEqual( |
| 195 | + max_diff, |
| 196 | + 0, |
| 197 | + DECIMALS_OF_AGREEMENT, |
| 198 | + msg=f"InplaceAdd TRT outputs don't match with the original model.", |
| 199 | + ) |
| 200 | + torch._dynamo.reset() |
| 201 | + |
| 202 | + def test_input_modifications_mul(self): |
| 203 | + class InplaceMul(torch.nn.Module): |
| 204 | + def forward(self, x): |
| 205 | + x *= 5.0 |
| 206 | + x *= 1.9 |
| 207 | + y = x + 1 |
| 208 | + y /= 1.3 |
| 209 | + return y |
| 210 | + |
| 211 | + inputs = [ |
| 212 | + torch.rand( |
| 213 | + 1, |
| 214 | + 3, |
| 215 | + 5, |
| 216 | + 7, |
| 217 | + ).cuda(), |
| 218 | + ] |
| 219 | + |
| 220 | + fx_graph = torch.fx.symbolic_trace(InplaceMul()) |
| 221 | + |
| 222 | + # Validate that the results between Torch and Torch-TRT are similar |
| 223 | + optimized_model = torch_tensorrt.compile( |
| 224 | + fx_graph, |
| 225 | + "torch_compile", |
| 226 | + inputs, |
| 227 | + min_block_size=1, |
| 228 | + pass_through_build_failures=True, |
| 229 | + ) |
| 230 | + optimized_model_results = optimized_model(*inputs).detach().cpu() |
| 231 | + torch_model_results = fx_graph(*inputs).detach().cpu() |
| 232 | + |
| 233 | + max_diff = float( |
| 234 | + torch.max(torch.abs(optimized_model_results - torch_model_results)) |
| 235 | + ) |
| 236 | + self.assertAlmostEqual( |
| 237 | + max_diff, |
| 238 | + 0, |
| 239 | + DECIMALS_OF_AGREEMENT, |
| 240 | + msg=f"InplaceMul TRT outputs don't match with the original model.", |
| 241 | + ) |
| 242 | + torch._dynamo.reset() |
| 243 | + |
| 244 | + |
160 | 245 | if __name__ == "__main__":
|
161 | 246 | run_tests()
|
0 commit comments