|
1 | 1 | import inspect |
2 | 2 | import logging |
| 3 | +import warnings |
3 | 4 | from copy import deepcopy |
4 | 5 | from enum import Enum, auto |
5 | 6 | from typing import Any, Collection, Dict, Iterator, List, Optional, Set, Union |
@@ -464,6 +465,12 @@ def _process_kwarg_inputs(inputs: Any) -> Any: |
464 | 465 | ) |
465 | 466 |
|
466 | 467 | def forward(self, *args: Any, **kwargs: Any) -> Any: |
| 468 | + warnings.warn( |
| 469 | + "Direct calls to {self.__class__}.forward() are currently broken by due to https://github.com/pytorch/pytorch/issues/157183. Either call {self.__class__}(...) directly or use {self.__class__}._forward as a work around" |
| 470 | + ) |
| 471 | + return self._forward(*args, **kwargs) |
| 472 | + |
| 473 | + def _forward(self, *args: Any, **kwargs: Any) -> Any: |
467 | 474 | # Step 1: Check whether the input shape has changed |
468 | 475 | kwargs = MutableTorchTensorRTModule._process_kwarg_inputs(kwargs) |
469 | 476 | self._validate_inputs(*args, **kwargs) |
@@ -513,7 +520,9 @@ def __deepcopy__(self, memo: Any) -> Any: |
513 | 520 | return result |
514 | 521 |
|
515 | 522 | def __call__(self, *args: Any, **kwargs: Any) -> Any: |
516 | | - return self.forward(*args, **kwargs) |
| 523 | + # Due to https://github.com/pytorch/pytorch/issues/157183, we cannot use forward call, use _forward as a workaround. |
| 524 | + # This is a temporary fix. |
| 525 | + return self._forward(*args, **kwargs) |
517 | 526 |
|
518 | 527 | def __getattr__(self, name: str) -> Any: |
519 | 528 | if name in self.__dict__: |
|
0 commit comments