Skip to content

[DO NOT MERGE] New memory abstraction and AMRL implementation #4374

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ml-agents/mlagents/trainers/policy/torch_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(
conditional_sigma=self.condition_sigma_on_obs,
tanh_squash=tanh_squash,
)
self.m_size = self.actor_critic.memory_size

self.actor_critic.to(TestingConfiguration.device)

Expand Down
78 changes: 78 additions & 0 deletions ml-agents/mlagents/trainers/torch/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,81 @@ def lstm_layer(
forget_bias
)
return lstm


class AMRLMax(torch.nn.Module):
    """
    LSTM followed by a running-max aggregator, as described in AMRL:
    https://www.microsoft.com/en-us/research/publication/amrl-aggregated-memory-for-reinforcement-learning/

    The first half of every LSTM output is folded into an element-wise running
    maximum over time (with a straight-through gradient, see
    ``PassthroughMax``). The second half is passed through untouched. Both
    halves are concatenated and fed through ``num_post_layers`` blocks of
    linear layer + Swish.

    The memory vector consumed and produced by ``forward`` is laid out along
    its last dimension as::

        [ accumulator (hidden_size // 2) | h (hidden_size) | c (hidden_size) ]

    NOTE(review): ``forward`` indexes time on dim 1 and permutes the
    accumulator from dim 0 to dim 1, so it appears to support only
    ``batch_first=True`` and ``num_layers=1`` -- confirm before using other
    settings.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        batch_first: bool = True,
        forget_bias: float = 1.0,
        kernel_init: Initialization = Initialization.XavierGlorotUniform,
        bias_init: Initialization = Initialization.Zero,
        num_post_layers: int = 1,
    ):
        """
        :param input_size: Width of each input step fed to the LSTM.
        :param hidden_size: LSTM hidden width. Must be a positive even number
            because the LSTM output is split into two equal halves.
        :param num_layers: Forwarded to ``lstm_layer``.
        :param batch_first: Forwarded to ``lstm_layer``.
        :param forget_bias: Forwarded to ``lstm_layer``.
        :param kernel_init: Forwarded to ``lstm_layer``.
        :param bias_init: Forwarded to ``lstm_layer``.
        :param num_post_layers: Number of (linear, Swish) blocks applied to the
            aggregated output.
        :raises ValueError: If ``hidden_size`` is not a positive even number.
        """
        super().__init__()
        if hidden_size < 2 or hidden_size % 2 != 0:
            # An odd width makes the two-way split in forward() yield three
            # chunks and fail with an obscure unpacking error; fail early.
            raise ValueError(
                f"hidden_size must be a positive even number, got {hidden_size}"
            )
        self.lstm = lstm_layer(
            input_size,
            hidden_size,
            num_layers,
            batch_first,
            forget_bias,
            kernel_init,
            bias_init,
        )
        self.hidden_size = hidden_size
        # Plain list kept as an attribute for compatibility; the modules are
        # registered with torch through the Sequential below.
        self.layers = []
        for _ in range(num_post_layers):
            self.layers.append(
                linear_layer(
                    hidden_size,
                    hidden_size,
                    kernel_init=Initialization.KaimingHeNormal,
                    kernel_gain=1.0,
                )
            )
            self.layers.append(Swish())
        self.seq_layers = torch.nn.Sequential(*self.layers)

    @property
    def memory_size(self) -> int:
        """Width of the memory vector: accumulator + h + c."""
        return self.hidden_size // 2 + 2 * self.hidden_size

    def forward(self, input_tensor: torch.Tensor, memories: torch.Tensor):
        """
        Run the LSTM over a sequence and aggregate its outputs.

        :param input_tensor: Input sequence; time is read from dim 1.
        :param memories: Previous memory, ``memory_size`` wide on the last dim.
        :return: Tuple ``(output, output_mem)``. ``output`` is the aggregated
            sequence after the post layers, reshaped to
            ``(-1, sequence_length, hidden_size)``. ``output_mem`` is the new
            memory in the same layout as ``memories``.
        """
        half = self.hidden_size // 2
        # memories = accumulator (hidden_size // 2) + h0 (hidden_size) + c0 (hidden_size)
        acc, h0, c0 = torch.split(
            memories, [half, self.hidden_size, self.hidden_size], dim=-1
        )
        # The split results are strided views; cuDNN RNN kernels reject a
        # non-contiguous hidden state, so materialize them.
        hidden = (h0.contiguous(), c0.contiguous())
        lstm_out, (h0_out, c0_out) = self.lstm(input_tensor, hidden)
        h_half, other_half = torch.split(lstm_out, half, dim=-1)

        # Move the accumulator's leading dim to the time axis so it lines up
        # with a single-step slice of the LSTM output.
        m = acc.permute([1, 0, 2])
        all_c = []
        for t in range(h_half.shape[1]):
            h_half_subt = h_half[:, t : t + 1, :]
            m = AMRLMax.PassthroughMax.apply(m, h_half_subt)
            all_c.append(m)
        concat_c = torch.cat(all_c, dim=1)
        concat_out = torch.cat([concat_c, other_half], dim=-1)

        full_out = self.seq_layers(concat_out.reshape([-1, self.hidden_size]))
        full_out = full_out.reshape([-1, input_tensor.shape[1], self.hidden_size])
        output_mem = torch.cat([m.permute([1, 0, 2]), h0_out, c0_out], dim=-1)
        # Return the post-processed output; returning concat_out here would
        # silently bypass seq_layers and make num_post_layers a no-op.
        return full_out, output_mem

    class PassthroughMax(torch.autograd.Function):
        """Element-wise max whose backward pass is a straight-through copy."""

        @staticmethod
        def forward(ctx, tensor1, tensor2):
            """Return the element-wise maximum of the two tensors."""
            return torch.max(tensor1, tensor2)

        @staticmethod
        def backward(ctx, grad_output):
            """Hand the incoming gradient, unchanged, to both inputs."""
            return grad_output.clone(), grad_output.clone()
32 changes: 23 additions & 9 deletions ml-agents/mlagents/trainers/torch/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.layers import lstm_layer
from mlagents.trainers.torch.layers import AMRLMax

ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
EncoderFunction = Callable[
Expand Down Expand Up @@ -51,9 +51,9 @@ def __init__(
)

if self.use_lstm:
self.lstm = lstm_layer(self.h_size, self.m_size // 2, batch_first=True)
self.lstm = AMRLMax(self.h_size, self.m_size // 2, batch_first=True)
else:
self.lstm = None
self.lstm = None # type: ignore

def update_normalization(self, vec_inputs: List[torch.Tensor]) -> None:
for vec_input, vec_enc in zip(vec_inputs, self.vector_encoders):
Expand Down Expand Up @@ -104,10 +104,10 @@ def forward(
if self.use_lstm:
# Resize to (batch, sequence length, encoding size)
encoding = encoding.reshape([-1, sequence_length, self.h_size])
memories = torch.split(memories, self.m_size // 2, dim=-1)
# memories = torch.split(memories, self.m_size // 2, dim=-1)
encoding, memories = self.lstm(encoding, memories)
encoding = encoding.reshape([-1, self.m_size // 2])
memories = torch.cat(memories, dim=-1)
# memories = torch.cat(memories, dim=-1)
return encoding, memories


Expand Down Expand Up @@ -257,7 +257,7 @@ def __init__(
self.act_type = act_type
self.act_size = act_size
self.version_number = torch.nn.Parameter(torch.Tensor([2.0]))
self.memory_size = torch.nn.Parameter(torch.Tensor([0]))
self.memory_size_param = torch.nn.Parameter(torch.Tensor([0]))
self.is_continuous_int = torch.nn.Parameter(
torch.Tensor([int(act_type == ActionType.CONTINUOUS)])
)
Expand All @@ -279,6 +279,13 @@ def __init__(
self.encoding_size, act_size
)

@property
def memory_size(self) -> int:
    """Memory width required by the body's recurrent module, or 0 without one."""
    recurrent = self.network_body.lstm
    return 0 if recurrent is None else recurrent.memory_size

def update_normalization(self, vector_obs: List[torch.Tensor]) -> None:
self.network_body.update_normalization(vector_obs)

Expand Down Expand Up @@ -327,7 +334,7 @@ def forward(
sampled_actions,
dists[0].pdf(sampled_actions),
self.version_number,
self.memory_size,
self.memory_size_param,
self.is_continuous_int,
self.act_size_vector,
)
Expand Down Expand Up @@ -425,6 +432,13 @@ def __init__(
stream_names, observation_shapes, use_network_settings
)

@property
def memory_size(self) -> int:
    """
    Combined memory width: twice the recurrent module's memory size (the
    callers split it into an actor half and a critic half), or 0 when the
    body has no recurrent module.
    """
    recurrent = self.network_body.lstm
    if recurrent is None:
        return 0
    return 2 * recurrent.memory_size

def critic_pass(
self,
vec_inputs: List[torch.Tensor],
Expand All @@ -435,7 +449,7 @@ def critic_pass(
actor_mem, critic_mem = None, None
if self.use_lstm:
# Use only the back half of memories for critic
actor_mem, critic_mem = torch.split(memories, self.half_mem_size, -1)
actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1)
value_outputs, critic_mem_out = self.critic(
vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
)
Expand All @@ -456,7 +470,7 @@ def get_dist_and_value(
) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
if self.use_lstm:
# Use only the back half of memories for critic and actor
actor_mem, critic_mem = torch.split(memories, self.half_mem_size, dim=-1)
actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
else:
critic_mem = None
actor_mem = None
Expand Down