|
5 | 5 |
|
6 | 6 | from __future__ import annotations
|
7 | 7 |
|
| 8 | +import os |
8 | 9 | from dataclasses import dataclass
|
9 | 10 | from textwrap import indent
|
10 | 11 | from typing import (
|
|
41 | 42 |
|
42 | 43 | INDEX_TYPING = Union[int, torch.Tensor, np.ndarray, slice, List]
|
43 | 44 |
|
| 45 | +_NO_CHECK_SPEC_ENCODE = os.environ.get("NO_CHECK_SPEC_ENCODE", False) |
| 46 | +if _NO_CHECK_SPEC_ENCODE in ("0", "False", False): |
| 47 | + _NO_CHECK_SPEC_ENCODE = False |
| 48 | +elif _NO_CHECK_SPEC_ENCODE in ("1", "True", True): |
| 49 | + _NO_CHECK_SPEC_ENCODE = True |
| 50 | +else: |
| 51 | + raise NotImplementedError( |
| 52 | + "NO_CHECK_SPEC_ENCODE should be in 'True', 'False', '0' or '1'. " |
| 53 | + f"Got {_NO_CHECK_SPEC_ENCODE} instead." |
| 54 | + ) |
| 55 | + |
44 | 56 |
|
45 | 57 | def _default_dtype_and_device(
|
46 | 58 | dtype: Union[None, torch.dtype],
|
@@ -214,7 +226,8 @@ def encode(self, val: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
|
214 | 226 | ):
|
215 | 227 | val = val.copy()
|
216 | 228 | val = torch.as_tensor(val, dtype=self.dtype, device=self.device)
|
217 |
| - self.assert_is_in(val) |
| 229 | + if not _NO_CHECK_SPEC_ENCODE: |
| 230 | + self.assert_is_in(val) |
218 | 231 | return val
|
219 | 232 |
|
220 | 233 | def to_numpy(self, val: torch.Tensor, safe: bool = True) -> np.ndarray:
|
|
0 commit comments