Skip to content

[Feature] Implemented device argument for modules.models #524

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions test/test_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,25 +227,23 @@ def test_gsde(
):
torch.manual_seed(0)
if gSDE:
model = torch.nn.LazyLinear(action_dim)
model = torch.nn.LazyLinear(action_dim, device=device)
in_keys = ["observation"]
module = TensorDictSequential(
TensorDictModule(model, in_keys=in_keys, out_keys=["action"]),
TensorDictModule(
LazygSDEModule(),
LazygSDEModule(device=device),
in_keys=["action", "observation", "_eps_gSDE"],
out_keys=["loc", "scale", "action", "_eps_gSDE"],
),
).to(device)
)
distribution_class = IndependentNormal
distribution_kwargs = {}
else:
in_keys = ["observation"]
model = torch.nn.LazyLinear(action_dim * 2)
model = torch.nn.LazyLinear(action_dim * 2, device=device)
wrapper = NormalParamWrapper(model)
module = TensorDictModule(
wrapper, in_keys=in_keys, out_keys=["loc", "scale"]
).to(device)
module = TensorDictModule(wrapper, in_keys=in_keys, out_keys=["loc", "scale"])
distribution_class = TanhNormal
distribution_kwargs = {"min": -bound, "max": bound}
spec = NdBoundedTensorSpec(
Expand Down
96 changes: 85 additions & 11 deletions test/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@
FunctionalModule,
FunctionalModuleWithBuffers,
)
from torchrl.modules.models import MLP, NoisyLazyLinear, NoisyLinear
from torchrl.modules.models import ConvNet, MLP, NoisyLazyLinear, NoisyLinear
from torchrl.modules.models.utils import SquashDims


@pytest.mark.parametrize("in_features", [3, 10, None])
@pytest.mark.parametrize("out_features", [3, (3, 10)])
@pytest.mark.parametrize("depth, num_cells", [(3, 32), (None, (32, 32, 32))])
@pytest.mark.parametrize("activation_kwargs", [{"inplace": True}, {}])
@pytest.mark.parametrize(
"activation_class, activation_kwargs",
[(nn.ReLU, {"inplace": True}), (nn.ReLU, {}), (nn.PReLU, {})],
)
@pytest.mark.parametrize(
"norm_class, norm_kwargs",
[(nn.LazyBatchNorm1d, {}), (nn.BatchNorm1d, {"num_features": 32})],
Expand All @@ -45,6 +49,7 @@ def test_mlp(
out_features,
depth,
num_cells,
activation_class,
activation_kwargs,
bias_last_layer,
norm_class,
Expand All @@ -61,14 +66,15 @@ def test_mlp(
out_features=out_features,
depth=depth,
num_cells=num_cells,
activation_class=nn.ReLU,
activation_class=activation_class,
activation_kwargs=activation_kwargs,
norm_class=norm_class,
norm_kwargs=norm_kwargs,
bias_last_layer=bias_last_layer,
single_bias_last_layer=False,
layer_class=layer_class,
).to(device)
device=device,
)
if in_features is None:
in_features = 5
x = torch.randn(batch, in_features, device=device)
Expand All @@ -77,6 +83,72 @@ def test_mlp(
assert y.shape == torch.Size([batch, *out_features])


@pytest.mark.parametrize("in_features", [3, 10, None])
@pytest.mark.parametrize(
"input_size, depth, num_cells, kernel_sizes, strides, paddings, expected_features",
[(100, None, None, 3, 1, 0, 32 * 94 * 94), (100, 3, 32, 3, 1, 1, 32 * 100 * 100)],
)
@pytest.mark.parametrize(
"activation_class, activation_kwargs",
[(nn.ReLU, {"inplace": True}), (nn.ReLU, {}), (nn.PReLU, {})],
)
@pytest.mark.parametrize(
"norm_class, norm_kwargs",
[(None, None), (nn.LazyBatchNorm2d, {}), (nn.BatchNorm2d, {"num_features": 32})],
)
@pytest.mark.parametrize("bias_last_layer", [True, False])
@pytest.mark.parametrize(
"aggregator_class, aggregator_kwargs",
[(SquashDims, {})],
)
@pytest.mark.parametrize("squeeze_output", [False])
@pytest.mark.parametrize("device", get_available_devices())
def test_convnet(
in_features,
depth,
num_cells,
kernel_sizes,
strides,
paddings,
activation_class,
activation_kwargs,
norm_class,
norm_kwargs,
bias_last_layer,
aggregator_class,
aggregator_kwargs,
squeeze_output,
device,
input_size,
expected_features,
seed=0,
):
torch.manual_seed(seed)
batch = 2
convnet = ConvNet(
in_features=in_features,
depth=depth,
num_cells=num_cells,
kernel_sizes=kernel_sizes,
strides=strides,
paddings=paddings,
activation_class=activation_class,
activation_kwargs=activation_kwargs,
norm_class=norm_class,
norm_kwargs=norm_kwargs,
bias_last_layer=bias_last_layer,
aggregator_class=aggregator_class,
aggregator_kwargs=aggregator_kwargs,
squeeze_output=squeeze_output,
device=device,
)
if in_features is None:
in_features = 5
x = torch.randn(batch, in_features, input_size, input_size, device=device)
y = convnet(x)
assert y.shape == torch.Size([batch, expected_features])


@pytest.mark.parametrize(
"layer_class",
[
Expand All @@ -87,7 +159,7 @@ def test_mlp(
@pytest.mark.parametrize("device", get_available_devices())
def test_noisy(layer_class, device, seed=0):
torch.manual_seed(seed)
layer = layer_class(3, 4).to(device)
layer = layer_class(3, 4, device=device)
x = torch.randn(10, 3, device=device)
y1 = layer(x)
layer.reset_noise()
Expand All @@ -106,25 +178,25 @@ def test_value_based_policy(device):
action_spec = OneHotDiscreteTensorSpec(action_dim)

def make_net():
net = MLP(in_features=obs_dim, out_features=action_dim, depth=2)
net = MLP(in_features=obs_dim, out_features=action_dim, depth=2, device=device)
for mod in net.modules():
if hasattr(mod, "bias") and mod.bias is not None:
mod.bias.data.zero_()
return net

actor = QValueActor(spec=action_spec, module=make_net(), safe=True).to(device)
actor = QValueActor(spec=action_spec, module=make_net(), safe=True)
obs = torch.zeros(2, obs_dim, device=device)
td = TensorDict(batch_size=[2], source={"observation": obs})
action = actor(td).get("action")
assert (action.sum(-1) == 1).all()

actor = QValueActor(spec=action_spec, module=make_net(), safe=False).to(device)
actor = QValueActor(spec=action_spec, module=make_net(), safe=False)
obs = torch.randn(2, obs_dim, device=device)
td = TensorDict(batch_size=[2], source={"observation": obs})
action = actor(td).get("action")
assert (action.sum(-1) == 1).all()

actor = QValueActor(spec=action_spec, module=make_net(), safe=False).to(device)
actor = QValueActor(spec=action_spec, module=make_net(), safe=False)
obs = torch.zeros(2, obs_dim, device=device)
td = TensorDict(batch_size=[2], source={"observation": obs})
action = actor(td).get("action")
Expand Down Expand Up @@ -198,7 +270,8 @@ def test_lstm_net(device, out_features, hidden_size, num_layers, has_precond_hid
"num_layers": num_layers,
},
{"out_features": hidden_size},
).to(device)
device=device,
)
# test single step vs multi-step
x = torch.randn(batch, time_steps, in_features, device=device)
x_unbind = x.unbind(1)
Expand Down Expand Up @@ -264,7 +337,8 @@ def test_lstm_net_nobatch(device, out_features, hidden_size):
out_features,
{"input_size": hidden_size, "hidden_size": hidden_size},
{"out_features": hidden_size},
).to(device)
device=device,
)
# test single step vs multi-step
x = torch.randn(time_steps, in_features, device=device)
x_unbind = x.unbind(0)
Expand Down
38 changes: 22 additions & 16 deletions torchrl/modules/models/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class NoisyLinear(nn.Linear):
out_features (int): out features dimension
bias (bool): if True, a bias term will be added to the matrix multiplication: Ax + b.
default: True
device (str, int or torch.device, optional): device of the layer.
device (DEVICE_TYPING, optional): device of the layer.
default: "cpu"
dtype (torch.dtype, optional): dtype of the parameters.
default: None
Expand Down Expand Up @@ -157,8 +157,7 @@ class NoisyLazyLinear(LazyModuleMixin, NoisyLinear):
out_features (int): out features dimension
bias (bool): if True, a bias term will be added to the matrix multiplication: Ax + b.
default: True
device (str, int or torch.device, optional): device of the layer.
default: "cpu"
device (DEVICE_TYPING, optional): device of the layer.
dtype (torch.dtype, optional): dtype of the parameters.
default: None
std_init (scalar): initial value of the Gaussian standard deviation before optimization.
Expand All @@ -173,7 +172,7 @@ def __init__(
dtype: Optional[torch.dtype] = None,
std_init: float = 0.1,
):
super().__init__(0, 0, False)
super().__init__(0, 0, False, device=device)
self.out_features = out_features
self.std_init = std_init

Expand Down Expand Up @@ -260,6 +259,7 @@ class gSDEModule(nn.Module):
scale_max (float, optional): max value of the scale.
transform (torch.distribution.Transform, optional): a transform to apply
to the sampled action.
device (DEVICE_TYPING, optional): device to create the model on.

Examples:
>>> from torchrl.modules import TensorDictModule, TensorDictSequential, ProbabilisticActor, TanhNormal
Expand Down Expand Up @@ -308,6 +308,7 @@ def __init__(
scale_max: float = 10.0,
learn_sigma: bool = True,
transform: Optional[d.Transform] = None,
device: Optional[DEVICE_TYPING] = None,
) -> None:
super().__init__()
self.action_dim = action_dim
Expand All @@ -321,18 +322,22 @@ def __init__(
sigma_init = inv_softplus(math.sqrt((1.0 - scale_min) / state_dim))
self.register_parameter(
"log_sigma",
nn.Parameter(torch.zeros((action_dim, state_dim), requires_grad=True)),
nn.Parameter(
torch.zeros(
(action_dim, state_dim), requires_grad=True, device=device
)
),
)
else:
if sigma_init is None:
sigma_init = math.sqrt((1.0 - scale_min) / state_dim)
self.register_buffer(
"_sigma",
torch.full((action_dim, state_dim), sigma_init),
torch.full((action_dim, state_dim), sigma_init, device=device),
)

if sigma_init != 0.0:
self.register_buffer("sigma_init", torch.tensor(sigma_init))
self.register_buffer("sigma_init", torch.tensor(sigma_init, device=device))

@property
def sigma(self):
Expand Down Expand Up @@ -417,11 +422,8 @@ def __init__(
scale_max: float = 10.0,
learn_sigma: bool = True,
transform: Optional[d.Transform] = None,
device: Optional[DEVICE_TYPING] = None,
) -> None:
factory_kwargs = {
"device": torch.device("cpu"),
"dtype": torch.get_default_dtype(),
}
super().__init__(
0,
0,
Expand All @@ -430,7 +432,12 @@ def __init__(
scale_max=scale_max,
learn_sigma=learn_sigma,
transform=transform,
device=device,
)
factory_kwargs = {
"device": device,
"dtype": torch.get_default_dtype(),
}
self._sigma_init = sigma_init
self.sigma_init = UninitializedBuffer(**factory_kwargs)
if learn_sigma:
Expand All @@ -445,18 +452,17 @@ def initialize_parameters(
self, mu: torch.Tensor, state: torch.Tensor, _eps_gSDE: torch.Tensor
) -> None:
if self.has_uninitialized_params():
device = mu.device
action_dim = mu.shape[-1]
state_dim = state.shape[-1]
with torch.no_grad():
if state.ndimension() > 2:
state = state.flatten(0, -2).squeeze(0)
if state.ndimension() == 1:
state_flatten_var = torch.ones(1).to(device)
state_flatten_var = torch.ones(1, device=mu.device)
else:
state_flatten_var = state.pow(2).mean(dim=0).reciprocal()

self.sigma_init.materialize(state_flatten_var.shape, device=device)
self.sigma_init.materialize(state_flatten_var.shape)
if self.learn_sigma:
if self._sigma_init is None:
state_flatten_var.clamp_min_(self.scale_min)
Expand All @@ -471,7 +477,7 @@ def initialize_parameters(
)
)

self.log_sigma.materialize((action_dim, state_dim), device=device)
self.log_sigma.materialize((action_dim, state_dim))
self.log_sigma.data.copy_(self.sigma_init.expand_as(self.log_sigma))

else:
Expand All @@ -483,5 +489,5 @@ def initialize_parameters(
self.sigma_init.data.copy_(
(state_flatten_var / state_dim).sqrt() * self._sigma_init
)
self._sigma.materialize((action_dim, state_dim), device=device)
self._sigma.materialize((action_dim, state_dim))
self._sigma.data.copy_(self.sigma_init.expand_as(self._sigma))
Loading