[poca] Make Observation Encoder a module #5093

Merged: 2 commits, Mar 12, 2021
176 changes: 121 additions & 55 deletions ml-agents/mlagents/trainers/torch/networks.py
@@ -7,14 +7,19 @@
from mlagents.trainers.torch.action_model import ActionModel
from mlagents.trainers.torch.agent_action import AgentAction
from mlagents.trainers.torch.action_log_probs import ActionLogProbs
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.settings import NetworkSettings, EncoderType
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.layers import LSTM, LinearEncoder
from mlagents.trainers.torch.encoders import VectorInput
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import ObsUtil
from mlagents.trainers.torch.attention import EntityEmbedding, ResidualSelfAttention
from mlagents.trainers.torch.attention import (
EntityEmbedding,
ResidualSelfAttention,
get_zero_entities_mask,
)
from mlagents.trainers.exception import UnityTrainerException


ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
@@ -25,6 +30,103 @@
EPSILON = 1e-7


class ObservationEncoder(nn.Module):
def __init__(
self,
observation_specs: List[ObservationSpec],
h_size: int,
vis_encode_type: EncoderType,
normalize: bool = False,
):
"""
Creates an ObservationEncoder that can process and encode a set of observations.
Uses a ResidualSelfAttention (RSA) block when variable length observations are present.
"""
super().__init__()
self.processors, self.embedding_sizes = ModelUtils.create_input_processors(
observation_specs, h_size, vis_encode_type, normalize=normalize
)
self.rsa, self.x_self_encoder = ModelUtils.create_residual_self_attention(
self.processors, self.embedding_sizes, h_size
)
if self.rsa is not None:
total_enc_size = sum(self.embedding_sizes) + h_size
else:
total_enc_size = sum(self.embedding_sizes)
self.normalize = normalize
self._total_enc_size = total_enc_size

@property
def total_enc_size(self) -> int:
"""
Returns the total encoding size for this ObservationEncoder.
"""
return self._total_enc_size

def update_normalization(self, buffer: AgentBuffer) -> None:
obs = ObsUtil.from_buffer(buffer, len(self.processors))
for vec_input, enc in zip(obs, self.processors):
if isinstance(enc, VectorInput):
enc.update_normalization(torch.as_tensor(vec_input))

def copy_normalization(self, other_encoder: "ObservationEncoder") -> None:
if self.normalize:
for n1, n2 in zip(self.processors, other_encoder.processors):
if isinstance(n1, VectorInput) and isinstance(n2, VectorInput):
n1.copy_normalization(n2)

def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
"""
Encode observations using this module's input processors and, for variable
length observations, its RSA.
:param inputs: List of Tensors corresponding to a set of obs.
"""
encodes = []
var_len_processor_inputs: List[Tuple[nn.Module, torch.Tensor]] = []

for idx, processor in enumerate(self.processors):
if not isinstance(processor, EntityEmbedding):
# The input can be encoded without having to process other inputs
obs_input = inputs[idx]
processed_obs = processor(obs_input)
encodes.append(processed_obs)
else:
var_len_processor_inputs.append((processor, inputs[idx]))
if len(encodes) != 0:
encoded_self = torch.cat(encodes, dim=1)
input_exist = True
else:
input_exist = False
if len(var_len_processor_inputs) > 0 and self.rsa is not None:
# Some inputs need to be processed with a variable length encoder
masks = get_zero_entities_mask([p_i[1] for p_i in var_len_processor_inputs])
embeddings: List[torch.Tensor] = []
processed_self = (
self.x_self_encoder(encoded_self)
if input_exist and self.x_self_encoder is not None
else None
)
for processor, var_len_input in var_len_processor_inputs:
embeddings.append(processor(processed_self, var_len_input))
qkv = torch.cat(embeddings, dim=1)
attention_embedding = self.rsa(qkv, masks)
if not input_exist:
encoded_self = torch.cat([attention_embedding], dim=1)
input_exist = True
else:
encoded_self = torch.cat([encoded_self, attention_embedding], dim=1)

if not input_exist:
raise UnityTrainerException(
"The trainer was unable to process any of the provided inputs. "
"Make sure the trained agents has at least one sensor attached to them."
)

return encoded_self
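
As a quick illustration of the interface the new module exposes: one processor per observation, outputs concatenated into a single encoding, and a `total_enc_size` property that downstream layers (such as the `LinearEncoder` in `NetworkBody` below) use to size their inputs. The following is a minimal standalone sketch of that pattern; `ToyVectorInput`, `ToyObservationEncoder`, and the sizes are hypothetical stand-ins rather than the ml-agents classes, and the RSA path is omitted.

```python
from typing import List

import torch
from torch import nn


class ToyVectorInput(nn.Module):
    """Hypothetical stand-in for a per-observation processor such as VectorInput."""

    def __init__(self, size: int):
        super().__init__()
        self.size = size

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        # A real processor would normalize and/or encode here.
        return obs


class ToyObservationEncoder(nn.Module):
    """Schematic of ObservationEncoder: per-obs processors, concatenated output."""

    def __init__(self, obs_sizes: List[int]):
        super().__init__()
        self.processors = nn.ModuleList([ToyVectorInput(s) for s in obs_sizes])
        self._total_enc_size = sum(obs_sizes)

    @property
    def total_enc_size(self) -> int:
        return self._total_enc_size

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        encodes = [proc(obs) for proc, obs in zip(self.processors, inputs)]
        return torch.cat(encodes, dim=1)


encoder = ToyObservationEncoder([4, 6])
batch = [torch.randn(2, 4), torch.randn(2, 6)]
assert encoder(batch).shape == (2, encoder.total_enc_size)  # (2, 10)
```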


class NetworkBody(nn.Module):
def __init__(
self,
@@ -41,22 +143,13 @@ def __init__(
if network_settings.memory is not None
else 0
)

self.processors, self.embedding_sizes = ModelUtils.create_input_processors(
self.observation_encoder = ObservationEncoder(
observation_specs,
self.h_size,
network_settings.vis_encode_type,
normalize=self.normalize,
)

self.rsa, self.x_self_encoder = ModelUtils.create_residual_self_attention(
self.processors, self.embedding_sizes, self.h_size
self.normalize,
)
if self.rsa is not None:
total_enc_size = sum(self.embedding_sizes) + self.h_size
else:
total_enc_size = sum(self.embedding_sizes)

total_enc_size = self.observation_encoder.total_enc_size
total_enc_size += encoded_act_size
self.linear_encoder = LinearEncoder(
total_enc_size, network_settings.num_layers, self.h_size
@@ -68,16 +161,10 @@ def __init__(
self.lstm = None # type: ignore

def update_normalization(self, buffer: AgentBuffer) -> None:
obs = ObsUtil.from_buffer(buffer, len(self.processors))
for vec_input, enc in zip(obs, self.processors):
if isinstance(enc, VectorInput):
enc.update_normalization(torch.as_tensor(vec_input))
self.observation_encoder.update_normalization(buffer)

def copy_normalization(self, other_network: "NetworkBody") -> None:
if self.normalize:
for n1, n2 in zip(self.processors, other_network.processors):
if isinstance(n1, VectorInput) and isinstance(n2, VectorInput):
n1.copy_normalization(n2)
self.observation_encoder.copy_normalization(other_network.observation_encoder)

@property
def memory_size(self) -> int:
@@ -90,9 +177,7 @@ def forward(
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
encoded_self = ModelUtils.encode_observations(
inputs, self.processors, self.rsa, self.x_self_encoder
)
encoded_self = self.observation_encoder(inputs)
if actions is not None:
encoded_self = torch.cat([encoded_self, actions], dim=1)
encoding = self.linear_encoder(encoded_self)
@@ -127,27 +212,18 @@ def __init__(
if network_settings.memory is not None
else 0
)
self.processors, _input_size = ModelUtils.create_input_processors(
self.action_spec = action_spec
self.observation_encoder = ObservationEncoder(
observation_specs,
self.h_size,
network_settings.vis_encode_type,
normalize=self.normalize,
)
self.action_spec = action_spec
# This RSA and input are for variable length obs, not for multi-agent.
(
self.input_rsa,
self.input_x_self_encoder,
) = ModelUtils.create_residual_self_attention(
self.processors, _input_size, self.h_size
self.normalize,
)
if self.input_rsa is not None:
_input_size.append(self.h_size)

# Modules for multi-agent self-attention
obs_only_ent_size = sum(_input_size)
obs_only_ent_size = self.observation_encoder.total_enc_size
q_ent_size = (
sum(_input_size)
obs_only_ent_size
+ sum(self.action_spec.discrete_branches)
+ self.action_spec.continuous_size
)
@@ -173,16 +249,10 @@ def memory_size(self) -> int:
return self.lstm.memory_size if self.use_lstm else 0

def update_normalization(self, buffer: AgentBuffer) -> None:
obs = ObsUtil.from_buffer(buffer, len(self.processors))
for vec_input, enc in zip(obs, self.processors):
if isinstance(enc, VectorInput):
enc.update_normalization(torch.as_tensor(vec_input))
self.observation_encoder.update_normalization(buffer)

def copy_normalization(self, other_network: "MultiAgentNetworkBody") -> None:
if self.normalize:
for n1, n2 in zip(self.processors, other_network.processors):
if isinstance(n1, VectorInput) and isinstance(n2, VectorInput):
n1.copy_normalization(n2)
self.observation_encoder.copy_normalization(other_network.observation_encoder)

def _get_masks_from_nans(self, obs_tensors: List[torch.Tensor]) -> torch.Tensor:
"""
Expand Down Expand Up @@ -243,9 +313,7 @@ def forward(
obs_attn_mask = self._get_masks_from_nans(obs)
obs = self._copy_and_remove_nans_from_obs(obs, obs_attn_mask)
for inputs, action in zip(obs, actions):
encoded = ModelUtils.encode_observations(
inputs, self.processors, self.input_rsa, self.input_x_self_encoder
)
encoded = self.observation_encoder(inputs)
cat_encodes = [
encoded,
action.to_flat(self.action_spec.discrete_branches),
Expand All @@ -260,9 +328,7 @@ def forward(
obs_only_attn_mask = self._get_masks_from_nans(obs_only)
obs_only = self._copy_and_remove_nans_from_obs(obs_only, obs_only_attn_mask)
for inputs in obs_only:
encoded = ModelUtils.encode_observations(
inputs, self.processors, self.input_rsa, self.input_x_self_encoder
)
encoded = self.observation_encoder(inputs)
concat_encoded_obs.append(encoded)
g_inp = torch.stack(concat_encoded_obs, dim=1)
self_attn_masks.append(obs_only_attn_mask)
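
The body of `_get_masks_from_nans` is collapsed in this view. The convention the surrounding code relies on (see `_copy_and_remove_nans_from_obs`) is that padded, absent agents arrive as NaN-filled observation tensors, so an attention mask can be recovered by testing a single element per agent slot. Below is a minimal sketch of that idea, with illustrative names and a simplified one-obs-per-agent signature; it is not the merged implementation.

```python
from typing import List

import torch


def masks_from_nans(obs_per_agent: List[torch.Tensor]) -> torch.Tensor:
    # One (batch, obs_size) tensor per agent slot; padded slots are entirely
    # NaN, so checking a single element per slot is enough.
    first_elements = torch.stack(
        [obs.flatten(start_dim=1)[:, 0] for obs in obs_per_agent], dim=1
    )
    # 1.0 marks a padded slot, 0.0 a real agent.
    return torch.isnan(first_elements).float()


real = torch.randn(3, 5)
padded = torch.full((3, 5), float("nan"))
print(masks_from_nans([real, padded]))  # column of 0s, then a column of 1s
```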
@@ -530,10 +596,10 @@ def forward(
end = 0
vis_index = 0
var_len_index = 0
for i, enc in enumerate(self.network_body.processors):
for i, enc in enumerate(self.network_body.observation_encoder.processors):
if isinstance(enc, VectorInput):
# This is a vec_obs
vec_size = self.network_body.embedding_sizes[i]
vec_size = self.network_body.observation_encoder.embedding_sizes[i]
end = start + vec_size
inputs.append(concatenated_vec_obs[:, start:end])
start = end
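
In the export-time `forward` shown in the last hunk, all vector observations arrive concatenated into one tensor and must be sliced back into per-sensor inputs; after this change the slice widths are read from `observation_encoder.embedding_sizes`. A minimal sketch of that slicing pattern, with made-up sizes:

```python
import torch

sizes = [4, 6, 3]  # per-sensor vector obs widths, e.g. embedding_sizes entries
concatenated_vec_obs = torch.randn(2, sum(sizes))

inputs = []
start = 0
for size in sizes:
    end = start + size
    inputs.append(concatenated_vec_obs[:, start:end])
    start = end

# torch.split expresses the same loop in a single call.
for manual, split in zip(inputs, torch.split(concatenated_vec_obs, sizes, dim=1)):
    assert torch.equal(manual, split)
```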
63 changes: 1 addition & 62 deletions ml-agents/mlagents/trainers/torch/utils.py
@@ -11,11 +11,7 @@
VectorInput,
)
from mlagents.trainers.settings import EncoderType, ScheduleType
from mlagents.trainers.torch.attention import (
EntityEmbedding,
ResidualSelfAttention,
get_zero_entities_mask,
)
from mlagents.trainers.torch.attention import EntityEmbedding, ResidualSelfAttention
from mlagents.trainers.exception import UnityTrainerException
from mlagents_envs.base_env import ObservationSpec, DimensionProperty

@@ -372,63 +368,6 @@ def create_residual_self_attention(
rsa = ResidualSelfAttention(hidden_size, entity_num_max)
return rsa, x_self_encoder

@staticmethod
def encode_observations(
inputs: List[torch.Tensor],
processors: nn.ModuleList,
rsa: Optional[ResidualSelfAttention],
x_self_encoder: Optional[LinearEncoder],
) -> torch.Tensor:
"""
Helper method to encode observations using a list of processors and an RSA.
:param inputs: List of Tensors corresponding to a set of obs.
:param processors: a ModuleList of the input processors to be applied to these obs.
:param rsa: Optionally, an RSA to use for variable length obs.
:param x_self_encoder: Optionally, an encoder to use for x_self (in this case, the non-variable inputs.).
"""
encodes = []
var_len_processor_inputs: List[Tuple[nn.Module, torch.Tensor]] = []

for idx, processor in enumerate(processors):
if not isinstance(processor, EntityEmbedding):
# The input can be encoded without having to process other inputs
obs_input = inputs[idx]
processed_obs = processor(obs_input)
encodes.append(processed_obs)
else:
var_len_processor_inputs.append((processor, inputs[idx]))
if len(encodes) != 0:
encoded_self = torch.cat(encodes, dim=1)
input_exist = True
else:
input_exist = False
if len(var_len_processor_inputs) > 0 and rsa is not None:
# Some inputs need to be processed with a variable length encoder
masks = get_zero_entities_mask([p_i[1] for p_i in var_len_processor_inputs])
embeddings: List[torch.Tensor] = []
processed_self = (
x_self_encoder(encoded_self)
if input_exist and x_self_encoder is not None
else None
)
for processor, var_len_input in var_len_processor_inputs:
embeddings.append(processor(processed_self, var_len_input))
qkv = torch.cat(embeddings, dim=1)
attention_embedding = rsa(qkv, masks)
if not input_exist:
encoded_self = torch.cat([attention_embedding], dim=1)
input_exist = True
else:
encoded_self = torch.cat([encoded_self, attention_embedding], dim=1)

if not input_exist:
raise UnityTrainerException(
"The trainer was unable to process any of the provided inputs. "
"Make sure the trained agents has at least one sensor attached to them."
)

return encoded_self

@staticmethod
def trust_region_value_loss(
values: Dict[str, torch.Tensor],