|
3 | 3 | import warnings |
4 | 4 | from copy import deepcopy |
5 | 5 | from enum import Enum, auto |
6 | | -from typing import Any, Dict, Iterator, Optional, Set, Union |
| 6 | +from typing import Any, Dict, Iterator, Optional, Set, Tuple, Union |
7 | 7 |
|
8 | 8 | import numpy as np |
9 | 9 | import torch |
@@ -70,7 +70,9 @@ def __init__( |
70 | 70 | strict: bool = True, |
71 | 71 | prefer_deferred_runtime_asserts_over_guards: bool = False, |
72 | 72 | weight_streaming_budget: Optional[int] = None, |
73 | | - enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None, |
| 73 | + enabled_precisions: Union[ |
| 74 | + Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]] |
| 75 | + ] = _defaults.ENABLED_PRECISIONS, |
74 | 76 | **kwargs: Any, |
75 | 77 | ) -> None: |
76 | 78 | """ |
@@ -127,6 +129,10 @@ def __init__( |
127 | 129 | self.refit_state = RefitState() |
128 | 130 | self.pytorch_model = _make_refit_change_trigger(pytorch_model, self.refit_state) |
129 | 131 | self.original_model = pytorch_model |
| 132 | + if pytorch_model.training: |
| 133 | + logger.warning( |
| 134 | + "The model may be in training mode, which may affect the performance of the compiled model!" |
| 135 | + ) |
130 | 136 | # Process settings |
131 | 137 | self.gm: Any = None |
132 | 138 | self.exp_program: Any = None |
@@ -162,8 +168,6 @@ def __init__( |
162 | 168 | "Weight stremaing budget is not set. Using auto weight streaming budget" |
163 | 169 | ) |
164 | 170 | self.enabled_precisions = enabled_precisions |
165 | | - if self.enabled_precisions is None: |
166 | | - self.enabled_precisions = _defaults.ENABLED_PRECISIONS |
167 | 171 |
|
168 | 172 | cls = self.__class__ |
169 | 173 | self.__class__ = type( |
|
0 commit comments