Skip to content

Commit 571d5aa

Browse files
committed
TensorSpec.encode domain check parametrization (#228)
1 parent aad5d04 commit 571d5aa

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

torchrl/data/tensor_specs.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from __future__ import annotations
77

8+
import os
89
from dataclasses import dataclass
910
from textwrap import indent
1011
from typing import (
@@ -41,6 +42,17 @@
4142

4243
INDEX_TYPING = Union[int, torch.Tensor, np.ndarray, slice, List]
4344

45+
_NO_CHECK_SPEC_ENCODE = os.environ.get("NO_CHECK_SPEC_ENCODE", False)
46+
if _NO_CHECK_SPEC_ENCODE in ("0", "False", False):
47+
_NO_CHECK_SPEC_ENCODE = False
48+
elif _NO_CHECK_SPEC_ENCODE in ("1", "True", True):
49+
_NO_CHECK_SPEC_ENCODE = True
50+
else:
51+
raise NotImplementedError(
52+
"NO_CHECK_SPEC_ENCODE should be in 'True', 'False', '0' or '1'. "
53+
f"Got {_NO_CHECK_SPEC_ENCODE} instead."
54+
)
55+
4456

4557
def _default_dtype_and_device(
4658
dtype: Union[None, torch.dtype],
@@ -214,7 +226,8 @@ def encode(self, val: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
214226
):
215227
val = val.copy()
216228
val = torch.as_tensor(val, dtype=self.dtype, device=self.device)
217-
self.assert_is_in(val)
229+
if not _NO_CHECK_SPEC_ENCODE:
230+
self.assert_is_in(val)
218231
return val
219232

220233
def to_numpy(self, val: torch.Tensor, safe: bool = True) -> np.ndarray:

0 commit comments

Comments
 (0)