
Commit 1a4b2cc

Author: Vincent Moens (committed)
[Feature] TensorDictMap hashing functions
ghstack-source-id: 1c959ee
Pull Request resolved: #2304
1 parent 194a5ff commit 1a4b2cc

File tree

5 files changed: +246 −13 lines


test/test_recipe.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

test/test_storage_map.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import importlib.util

import pytest

import torch

from torchrl.data.map import BinaryToDecimal, RandomProjectionHash, SipHash

_has_gym = importlib.util.find_spec("gymnasium", None) or importlib.util.find_spec(
    "gym", None
)


class TestHash:
    def test_binary_to_decimal(self):
        binary_to_decimal = BinaryToDecimal(
            num_bits=4, device="cpu", dtype=torch.int32, convert_to_binary=True
        )
        binary = torch.Tensor([[0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 10, 0]])
        decimal = binary_to_decimal(binary)

        assert decimal.shape == (2,)
        assert (decimal == torch.Tensor([3, 2])).all()

    def test_sip_hash(self):
        a = torch.rand((3, 2))
        b = a.clone()
        hash_module = SipHash(as_tensor=True)
        hash_a = torch.tensor(hash_module(a))
        hash_b = torch.tensor(hash_module(b))
        assert (hash_a == hash_b).all()

    @pytest.mark.parametrize("n_components", [None, 14])
    @pytest.mark.parametrize("scale", [0.001, 0.01, 1, 100, 1000])
    def test_randomprojection_hash(self, n_components, scale):
        torch.manual_seed(0)
        r = RandomProjectionHash(n_components=n_components)
        x = torch.randn(10000, 100).mul_(scale)
        y = r(x)
        if n_components is None:
            assert r.n_components == r._N_COMPONENTS_DEFAULT
        else:
            assert r.n_components == n_components

        assert y.shape == (10000,)
        assert y.unique().numel() == y.numel()


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
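Note that test_sip_hash only asserts determinism within a single interpreter session; per the warning in the SipHash docstring, matching hash values across separate runs additionally requires setting PYTHONHASHSEED before Python starts. A minimal sketch of the in-process guarantee:

import torch

from torchrl.data.map import SipHash

hash_module = SipHash(as_tensor=True)
a = torch.rand((3, 2))
# Equal inputs hash to equal codes within one process; reproducing the
# same codes across runs requires PYTHONHASHSEED to be set at launch.
assert (hash_module(a) == hash_module(a.clone())).all()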

torchrl/data/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+from .map import BinaryToDecimal, RandomProjectionHash, SipHash
 from .postprocs import MultiStep
 from .replay_buffers import (
     Flat2TED,
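With this re-export, the new hash modules can be imported straight from the top-level torchrl.data namespace as well as from the torchrl.data.map submodule; a minimal sketch:

# Equivalent import paths introduced by this commit:
from torchrl.data import BinaryToDecimal, RandomProjectionHash, SipHash  # via the re-export
# from torchrl.data.map import BinaryToDecimal, RandomProjectionHash, SipHash  # direct path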

torchrl/data/map/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .hash import BinaryToDecimal, RandomProjectionHash, SipHash

torchrl/data/map/hash.py

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from typing import Callable, List

import torch


class BinaryToDecimal(torch.nn.Module):
    """A module to convert binary-encoded tensors to decimals.

    This is a utility class that allows converting a binary encoding tensor (e.g. `1001`) to
    its decimal value (e.g. `9`).

    Args:
        num_bits (int): the number of bits to use for the bases table.
            The number of bits must be lower than or equal to the input length, and the
            input length must be divisible by ``num_bits``. If ``num_bits`` is lower than
            the number of bits in the input, the end result will be aggregated on the last
            dimension using :func:`~torch.sum`.
        device (torch.device): the device expected for inputs and outputs.
        dtype (torch.dtype): the output dtype.
        convert_to_binary (bool, optional): if ``True``, the input to the ``forward``
            method will be cast to a binary input using :func:`~torch.heaviside`.
            Defaults to ``False``.

    Examples:
        >>> binary_to_decimal = BinaryToDecimal(
        ...     num_bits=4, device="cpu", dtype=torch.int32, convert_to_binary=True
        ... )
        >>> binary = torch.Tensor([[0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 10, 0]])
        >>> decimal = binary_to_decimal(binary)
        >>> assert decimal.shape == (2,)
        >>> assert (decimal == torch.Tensor([3, 2])).all()
    """

    def __init__(
        self,
        num_bits: int,
        device: torch.device,
        dtype: torch.dtype,
        convert_to_binary: bool = False,
    ):
        super().__init__()
        self.convert_to_binary = convert_to_binary
        # Powers of two, ordered from the most to the least significant bit.
        self.bases = 2 ** torch.arange(num_bits - 1, -1, -1, device=device, dtype=dtype)
        self.num_bits = num_bits
        self.zero_tensor = torch.zeros((1,), device=device)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        num_features = features.shape[-1]
        if self.num_bits > num_features:
            raise ValueError(f"{num_features=} is less than {self.num_bits=}")
        elif num_features % self.num_bits != 0:
            raise ValueError(f"{num_features=} is not divisible by {self.num_bits=}")

        binary_features = (
            torch.heaviside(features, self.zero_tensor)
            if self.convert_to_binary
            else features
        )
        # Split the input into groups of ``num_bits`` bits, convert each group to
        # its decimal value, then aggregate the groups over the last dimension.
        feature_parts = binary_features.reshape(shape=(-1, self.num_bits))
        digits = torch.vmap(torch.dot, (None, 0))(
            self.bases, feature_parts.to(self.bases.dtype)
        )
        digits = digits.reshape(shape=(-1, features.shape[-1] // self.num_bits))
        aggregated_digits = torch.sum(digits, dim=-1)
        return aggregated_digits


class SipHash(torch.nn.Module):
    """A module to compute SipHash values for given tensors.

    A hash function module based on a SipHash implementation in Python.

    Args:
        as_tensor (bool, optional): if ``True``, the bytes will be turned into integers
            through the builtin ``hash`` function and mapped to a tensor. Default: ``True``.

    .. warning:: This module relies on the builtin ``hash`` function.
        To get reproducible results across runs, the ``PYTHONHASHSEED`` environment
        variable must be set before the code is run (changing this value during code
        execution is without effect).

    Examples:
        >>> # Assuming we set PYTHONHASHSEED=0 prior to running this code
        >>> a = torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
        >>> b = a.clone()
        >>> hash_module = SipHash(as_tensor=True)
        >>> hash_a = hash_module(a)
        >>> hash_a
        tensor([-4669941682990263259, -3778166555168484291, -9122128731510687521])
        >>> hash_b = hash_module(b)
        >>> assert (hash_a == hash_b).all()
    """

    def __init__(self, as_tensor: bool = True):
        super().__init__()
        self.as_tensor = as_tensor

    def forward(self, x: torch.Tensor) -> torch.Tensor | List[bytes]:
        hash_values = []
        if x.dtype in (torch.bfloat16,):
            # numpy has no bfloat16 support; fall back to float16.
            x = x.to(torch.float16)
        for x_i in x.detach().cpu().numpy():
            hash_value = x_i.tobytes()
            hash_values.append(hash_value)
        if not self.as_tensor:
            # Return the raw per-row byte strings.
            return hash_values
        result = torch.tensor([hash(x) for x in hash_values], dtype=torch.int64)
        return result


class RandomProjectionHash(SipHash):
    """A module that combines a random projection with :class:`~.SipHash` to get a low-dimensional tensor that is easier to hash.

    This module requires sklearn to be installed.

    Keyword Args:
        n_components (int, optional): the number of components of the low-dimensional
            projection. Defaults to 16.
        dtype_cast (torch.dtype, optional): the dtype to cast the projection to.
            Defaults to ``torch.bfloat16``.
        as_tensor (bool, optional): if ``True``, the bytes will be turned into integers
            through the builtin ``hash`` function and mapped to a tensor. Default: ``True``.
        init_method (callable, optional): the method used to initialize the projection
            matrix when :meth:`~.fit` is called. Defaults to :func:`torch.nn.init.normal_`.

    .. warning:: This module relies on the builtin ``hash`` function.
        To get reproducible results across runs, the ``PYTHONHASHSEED`` environment
        variable must be set before the code is run (changing this value during code
        execution is without effect).
    """

    _N_COMPONENTS_DEFAULT = 16

    def __init__(
        self,
        *,
        n_components: int | None = None,
        dtype_cast=torch.bfloat16,
        as_tensor: bool = True,
        init_method: Callable[[torch.Tensor], torch.Tensor | None] | None = None,
        **kwargs,
    ):
        if n_components is None:
            n_components = self._N_COMPONENTS_DEFAULT

        super().__init__(as_tensor=as_tensor)
        self.register_buffer("_n_components", torch.as_tensor(n_components))

        self._init = False
        if init_method is None:
            init_method = torch.nn.init.normal_
        self.init_method = init_method

        self.dtype_cast = dtype_cast
        # The projection matrix is materialized lazily, on the first call to fit().
        self.register_buffer("transform", torch.nn.UninitializedBuffer())

    @property
    def n_components(self):
        return self._n_components.item()

    def fit(self, x):
        """Fits the random projection to the input data."""
        self.transform.materialize(
            (x.shape[-1], self.n_components), dtype=self.dtype_cast, device=x.device
        )
        self.init_method(self.transform)
        self._init = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self._init:
            # Lazily fit the projection on the first batch seen.
            self.fit(x)
        x = x.to(self.dtype_cast) @ self.transform
        return super().forward(x)
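Taken together, the three modules give TensorDictMap its hashing primitives. A minimal end-to-end sketch (the shapes and inputs below are illustrative, not taken from the diff):

import torch

from torchrl.data.map import BinaryToDecimal, RandomProjectionHash, SipHash

# SipHash: hash each row of a tensor to a single int64 code.
sip = SipHash(as_tensor=True)
codes = sip(torch.rand(4, 8))
assert codes.shape == (4,) and codes.dtype == torch.int64

# RandomProjectionHash: project a wide input down to n_components
# (fitted lazily on the first forward call) before hashing.
proj = RandomProjectionHash(n_components=8)
proj_codes = proj(torch.randn(4, 128))
assert proj_codes.shape == (4,)

# BinaryToDecimal: threshold the input to bits, fold groups of num_bits
# bits into decimals, and sum the groups over the last dimension.
b2d = BinaryToDecimal(num_bits=4, device="cpu", dtype=torch.int32, convert_to_binary=True)
assert (b2d(torch.Tensor([[0, 0, 1, 0, 0, 0, 0, 1]])) == torch.tensor([3])).all()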
