feature(rjy): add crowdsim env and related configs #208

Closed · wants to merge 5 commits
214 changes: 214 additions & 0 deletions lzero/model/common.py
@@ -8,6 +8,8 @@
import math
from typing import Optional, Tuple
from dataclasses import dataclass
import logging
import itertools
import numpy as np
import torch
import torch.nn as nn
@@ -271,6 +273,218 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
return self.fc_representation(x)

class RGCNLayer(nn.Module):
"""
Overview:
Relational graph convolutional network layer.
"""
def __init__(
self,
robot_state_dim,
human_state_dim,
similarity_function,
num_layer = 2,
X_dim = 32,
layerwise_graph = False,
skip_connection = True,
wr_dims = [64, 32], # the last dim should equal to X_dim
wh_dims = [64, 32], # the last dim should equal to X_dim
final_state_dim = 32, # should equal to X_dim
norm_type= 'BN',
last_linear_layer_init_zero=True,
activation: Optional[nn.Module] = nn.ReLU(inplace=True),
):
super().__init__()

# design choice
# 'gaussian', 'embedded_gaussian', 'cosine', 'cosine_softmax', 'concatenation'
self.similarity_function = similarity_function
self.robot_state_dim = robot_state_dim
self.human_state_dim = human_state_dim
self.num_layer = num_layer
self.X_dim = X_dim
self.layerwise_graph = layerwise_graph
self.skip_connection = skip_connection

logging.info('Similarity_func: {}'.format(self.similarity_function))
logging.info('Layerwise_graph: {}'.format(self.layerwise_graph))
logging.info('Skip_connection: {}'.format(self.skip_connection))
logging.info('Number of layers: {}'.format(self.num_layer))

self.w_r = MLP(
in_channels=robot_state_dim,
hidden_channels=wr_dims[0],
out_channels=wr_dims[1],
layer_num=num_layer,
activation=activation,
norm_type=norm_type,
last_linear_layer_init_zero=last_linear_layer_init_zero,
) # inputs,64,32
self.w_h = MLP(
in_channels=human_state_dim,
hidden_channels=wh_dims[0],
out_channels=wh_dims[1],
layer_num=num_layer,
activation=activation,
norm_type=norm_type,
last_linear_layer_init_zero=last_linear_layer_init_zero,
) # inputs,64,32

if self.similarity_function == 'embedded_gaussian':
self.w_a = nn.Parameter(torch.randn(self.X_dim, self.X_dim))
elif self.similarity_function == 'concatenation':
# TODO: fix the dim size
self.w_a = MLP(
in_channels=2 * X_dim,
hidden_channels=2 * X_dim,
out_channels=1,
layer_num=1,
)

embedding_dim = self.X_dim
self.Ws = torch.nn.ParameterList()
for i in range(self.num_layer):
if i == 0:
self.Ws.append(nn.Parameter(torch.randn(self.X_dim, embedding_dim)))
elif i == self.num_layer - 1:
self.Ws.append(nn.Parameter(torch.randn(embedding_dim, final_state_dim)))
else:
self.Ws.append(nn.Parameter(torch.randn(embedding_dim, embedding_dim)))

# TODO: for visualization
self.A = None

def compute_similarity_matrix(self, X):
if self.similarity_function == 'embedded_gaussian':
A = torch.matmul(torch.matmul(X, self.w_a), X.permute(0, 2, 1))
normalized_A = nn.functional.softmax(A, dim=2)
elif self.similarity_function == 'gaussian':
A = torch.matmul(X, X.permute(0, 2, 1))
normalized_A = nn.functional.softmax(A, dim=2)
elif self.similarity_function == 'cosine':
A = torch.matmul(X, X.permute(0, 2, 1))
magnitudes = torch.norm(A, dim=2, keepdim=True)
norm_matrix = torch.matmul(magnitudes, magnitudes.permute(0, 2, 1))
normalized_A = torch.div(A, norm_matrix)
elif self.similarity_function == 'cosine_softmax':
A = torch.matmul(X, X.permute(0, 2, 1))
magnitudes = torch.norm(A, dim=2, keepdim=True)
norm_matrix = torch.matmul(magnitudes, magnitudes.permute(0, 2, 1))
normalized_A = nn.functional.softmax(torch.div(A, norm_matrix), dim=2)
elif self.similarity_function == 'concatenation':
indices = [pair for pair in itertools.product(list(range(X.size(1))), repeat=2)]
selected_features = torch.index_select(X, dim=1, index=torch.LongTensor(indices).reshape(-1))
pairwise_features = selected_features.reshape((-1, X.size(1) * X.size(1), X.size(2) * 2))
A = self.w_a(pairwise_features).reshape(-1, X.size(1), X.size(1))
normalized_A = A
elif self.similarity_function == 'squared':
A = torch.matmul(X, X.permute(0, 2, 1))
squared_A = A * A
normalized_A = squared_A / torch.sum(squared_A, dim=2, keepdim=True)
elif self.similarity_function == 'equal_attention':
normalized_A = (torch.ones(X.size(1), X.size(1)) / X.size(1)).expand(X.size(0), X.size(1), X.size(1))
elif self.similarity_function == 'diagonal':
normalized_A = (torch.eye(X.size(1), X.size(1))).expand(X.size(0), X.size(1), X.size(1))
else:
raise NotImplementedError

return normalized_A

def forward(self, state):
robot_states = state['robot_state']
human_states = state['human_state']

# compute feature matrix X
        robot_state_embeddings = self.w_r(robot_states)  # batch x num x embedding_dim
        human_state_embeddings = self.w_h(human_states)
        X = torch.cat([robot_state_embeddings, human_state_embeddings], dim=1)

# compute matrix A
if not self.layerwise_graph:
normalized_A = self.compute_similarity_matrix(X)
self.A = normalized_A[0, :, :].data.cpu().numpy() # total_num x total_num

# next_H = H = X

H = X.contiguous().clone()
next_H = H.contiguous().clone() # batch x total_num x embedding_dim
for i in range(self.num_layer): # 2
if self.layerwise_graph: # False
A = self.compute_similarity_matrix(H)
next_H = nn.functional.relu(torch.matmul(torch.matmul(A, H), self.Ws[i]))
else: # (A x H) x W_i
next_H = nn.functional.relu(torch.matmul(torch.matmul(normalized_A, H), self.Ws[i]))

if self.skip_connection:
# next_H += H
next_H = next_H + H
H = next_H.contiguous().clone()

return next_H
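
For reference, with the default 'embedded_gaussian' similarity and layerwise_graph=False, the update implemented by compute_similarity_matrix and forward reduces to the minimal plain-torch sketch below. The shapes and the random weight matrices are illustrative assumptions, not values from this PR.

import torch
import torch.nn.functional as F

# Illustrative shapes: 1 robot and 5 humans, each already embedded into X_dim = 32 features.
batch, num_robot, num_human, x_dim = 4, 1, 5, 32
robot_emb = torch.randn(batch, num_robot, x_dim)   # stands in for self.w_r(robot_states)
human_emb = torch.randn(batch, num_human, x_dim)   # stands in for self.w_h(human_states)
X = torch.cat([robot_emb, human_emb], dim=1)       # batch x total_num x X_dim

# Embedded-Gaussian similarity: A = softmax(X W_a X^T), normalized over each node's neighbors.
W_a = torch.randn(x_dim, x_dim)
A = F.softmax(X @ W_a @ X.transpose(1, 2), dim=2)  # batch x total_num x total_num

# Two rounds of message passing with skip connections: H <- relu(A H W_i) + H.
H = X
for W_i in [torch.randn(x_dim, x_dim), torch.randn(x_dim, x_dim)]:
    H = F.relu(A @ H @ W_i) + H

print(H.shape)  # torch.Size([4, 6, 32])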

class RepresentationNetworkGCN(nn.Module):

def __init__(
self,
robot_observation_shape: tuple,
human_observation_shape: tuple,
hidden_channels: int = 64,
layer_num: int = 2,
activation: Optional[nn.Module] = nn.ReLU(inplace=True),
last_linear_layer_init_zero: bool = True,
norm_type: Optional[str] = 'BN',
    ) -> None:
"""
Overview:

Arguments:
- robot_observation_shape (:obj:`tuple`): The shape of robot observation space, e.g. (2, 4).
- human_observation_shape (:obj:`tuple`): The shape of human observation space, e.g. (59, 4).
- hidden_channels (:obj:`int`): The channel of output hidden state.
- activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(). \
Use the inplace operation to speed up.
- last_linear_layer_init_zero (:obj:`bool`): Whether to initialize the last linear layer with zeros, \
which can provide stable zero outputs in the beginning, defaults to True.
- norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
"""
super().__init__()
# self.fc_representation = MLP(
# in_channels=observation_shape,
# hidden_channels=hidden_channels,
# out_channels=hidden_channels,
# layer_num=layer_num,
# activation=activation,
# norm_type=norm_type,
# # don't use activation and norm in the last layer of representation network is important for convergence.
# output_activation=False,
# output_norm=False,
# # last_linear_layer_init_zero=True is beneficial for convergence speed.
# last_linear_layer_init_zero=True,
# )
self.rgl_representation = RGCNLayer(
robot_state_dim=robot_observation_shape[-1],
human_state_dim=human_observation_shape[-1],
similarity_function='embedded_gaussian',
num_layer=layer_num,
X_dim=hidden_channels,
layerwise_graph=False,
skip_connection=True,
wr_dims=[64, hidden_channels],
wh_dims=[64, hidden_channels],
final_state_dim=hidden_channels,
# for mlp
norm_type=norm_type,
last_linear_layer_init_zero=last_linear_layer_init_zero,
activation=activation,
)

def forward(self, x: dict) -> torch.Tensor:
"""
Shapes:
- x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation.
- output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size.
"""
return self.rgl_representation(x)
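
A hypothetical call sketch for this representation network, assuming the example shapes from the docstring; the variable names and batch size are illustrative only.

import torch

# Hypothetical input for RepresentationNetworkGCN.forward, using the docstring shapes:
# robot observation (2, 4) and human observation (59, 4), batch size B = 8.
B = 8
x = {
    'robot_state': torch.randn(B, 2, 4),   # batch x num_robots x robot_state_dim
    'human_state': torch.randn(B, 59, 4),  # batch x num_humans x human_state_dim
}
# latent = RepresentationNetworkGCN((2, 4), (59, 4), hidden_channels=64)(x)
# Expected latent shape: (B, 2 + 59, 64), i.e. one 64-dim embedding per robot/human node.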

class PredictionNetwork(nn.Module):

88 changes: 65 additions & 23 deletions lzero/worker/muzero_collector.py
@@ -571,12 +571,28 @@ def collect(self,
self._env_info[env_id]['time'] += self._timer.value + interaction_duration
if timestep.done:
reward = timestep.info['eval_episode_return']
info = {
'reward': reward,
'time': self._env_info[env_id]['time'],
'step': self._env_info[env_id]['step'],
'visit_entropy': visit_entropies_lst[env_id] / eps_steps_lst[env_id],
}
if timestep.info.get('performance_info') is not None:
[Review comment from a collaborator (resolved): Please add an introduction of this PR's functionality to the description, along with the experimental benchmark results.]
mean_aoi = timestep.info['performance_info']['mean_aoi']
mean_energy_consumption = timestep.info['performance_info']['mean_energy_consumption']
collected_data_amount = timestep.info['performance_info']['collected_data_amount']
human_coverage = timestep.info['performance_info']['human_coverage']
info = {
'reward': reward,
'time': self._env_info[env_id]['time'],
'step': self._env_info[env_id]['step'],
'visit_entropy': visit_entropies_lst[env_id] / eps_steps_lst[env_id],
'mean_aoi': mean_aoi,
'mean_energy_consumption': mean_energy_consumption,
'collected_data_amount': collected_data_amount,
'human_coverage': human_coverage,
}
else:
info = {
'reward': reward,
'time': self._env_info[env_id]['time'],
'step': self._env_info[env_id]['step'],
'visit_entropy': visit_entropies_lst[env_id] / eps_steps_lst[env_id],
}
if self.policy_config.gumbel_algo:
info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id]
collected_episode += 1
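
As a side note, the two dict literals above differ only in the four CrowdSim metrics. A possible de-duplicated form of the same branch is sketched below; it is not part of the PR, and the performance_info keys are taken from the diff.

# Sketch of an equivalent, de-duplicated version of the branch above (illustrative only).
info = {
    'reward': reward,
    'time': self._env_info[env_id]['time'],
    'step': self._env_info[env_id]['step'],
    'visit_entropy': visit_entropies_lst[env_id] / eps_steps_lst[env_id],
}
performance_info = timestep.info.get('performance_info')
if performance_info is not None:
    # CrowdSim-specific episode metrics reported by the environment.
    info.update({
        'mean_aoi': performance_info['mean_aoi'],
        'mean_energy_consumption': performance_info['mean_energy_consumption'],
        'collected_data_amount': performance_info['collected_data_amount'],
        'human_coverage': performance_info['human_coverage'],
    })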
@@ -711,23 +727,49 @@ def _output_log(self, train_iter: int) -> None:
if self.policy_config.gumbel_algo:
completed_value = [d['completed_value'] for d in self._episode_info]
self._total_duration += duration
info = {
'episode_count': episode_count,
'envstep_count': envstep_count,
'avg_envstep_per_episode': envstep_count / episode_count,
'avg_envstep_per_sec': envstep_count / duration,
'avg_episode_per_sec': episode_count / duration,
'collect_time': duration,
'reward_mean': np.mean(episode_reward),
'reward_std': np.std(episode_reward),
'reward_max': np.max(episode_reward),
'reward_min': np.min(episode_reward),
'total_envstep_count': self._total_envstep_count,
'total_episode_count': self._total_episode_count,
'total_duration': self._total_duration,
'visit_entropy': np.mean(visit_entropy),
# 'each_reward': episode_reward,
}
if self._episode_info[0].get('mean_aoi') is not None:
episode_aoi = [d['mean_aoi'] for d in self._episode_info]
episode_energy_consumption = [d['mean_energy_consumption'] for d in self._episode_info]
episode_collected_data_amount = [d['collected_data_amount'] for d in self._episode_info]
episode_human_coverage = [d['human_coverage'] for d in self._episode_info]
info = {
'episode_count': episode_count,
'envstep_count': envstep_count,
'avg_envstep_per_episode': envstep_count / episode_count,
'avg_envstep_per_sec': envstep_count / duration,
'avg_episode_per_sec': episode_count / duration,
'collect_time': duration,
'reward_mean': np.mean(episode_reward),
'reward_std': np.std(episode_reward),
'reward_max': np.max(episode_reward),
'reward_min': np.min(episode_reward),
'total_envstep_count': self._total_envstep_count,
'total_episode_count': self._total_episode_count,
'total_duration': self._total_duration,
'visit_entropy': np.mean(visit_entropy),
'episode_mean_aoi': np.mean(episode_aoi),
'episode_mean_energy_consumption': np.mean(episode_energy_consumption),
'episode_mean_collected_data_amount': np.mean(episode_collected_data_amount),
'episode_mean_human_coverage': np.mean(episode_human_coverage),
}
else:
info = {
'episode_count': episode_count,
'envstep_count': envstep_count,
'avg_envstep_per_episode': envstep_count / episode_count,
'avg_envstep_per_sec': envstep_count / duration,
'avg_episode_per_sec': episode_count / duration,
'collect_time': duration,
'reward_mean': np.mean(episode_reward),
'reward_std': np.std(episode_reward),
'reward_max': np.max(episode_reward),
'reward_min': np.min(episode_reward),
'total_envstep_count': self._total_envstep_count,
'total_episode_count': self._total_episode_count,
'total_duration': self._total_duration,
'visit_entropy': np.mean(visit_entropy),
# 'each_reward': episode_reward,
}
if self.policy_config.gumbel_algo:
info['completed_value'] = np.mean(completed_value)
self._episode_info.clear()
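
The same duplication appears in _output_log. A sketch of extending the base info dict with the optional CrowdSim aggregates, using the key names from the diff (the helper mapping is illustrative):

# Sketch: extend the base `info` dict with the optional CrowdSim episode metrics,
# mapping per-episode keys to the aggregated log keys used above.
crowdsim_metrics = {
    'mean_aoi': 'episode_mean_aoi',
    'mean_energy_consumption': 'episode_mean_energy_consumption',
    'collected_data_amount': 'episode_mean_collected_data_amount',
    'human_coverage': 'episode_mean_human_coverage',
}
if self._episode_info and self._episode_info[0].get('mean_aoi') is not None:
    for episode_key, log_key in crowdsim_metrics.items():
        info[log_key] = np.mean([d[episode_key] for d in self._episode_info])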
29 changes: 15 additions & 14 deletions lzero/worker/muzero_evaluator.py
@@ -337,21 +337,22 @@ def eval(
action_mask_dict[env_id] = to_ndarray(obs['action_mask'])
to_play_dict[env_id] = to_ndarray(obs['to_play'])

dones[env_id] = done
if t.done:
# Env reset is done by env_manager automatically.
self._policy.reset([env_id])
reward = t.info['eval_episode_return']
saved_info = {'eval_episode_return': t.info['eval_episode_return']}
if 'episode_info' in t.info:
saved_info.update(t.info['episode_info'])
eval_monitor.update_info(env_id, saved_info)
eval_monitor.update_reward(env_id, reward)
self._logger.info(
"[EVALUATOR]env {} finish episode, final reward: {}, current episode: {}".format(
env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode()
)
dones[env_id] = done
if t.done:
# Env reset is done by env_manager automatically.
self._policy.reset([env_id])
reward = t.info['eval_episode_return']
                    # Only one of 'performance_info' and 'episode_info' is expected in the info dict.
if 'performance_info' in t.info:
eval_monitor.update_info(env_id, t.info['performance_info'])
elif 'episode_info' in t.info:
eval_monitor.update_info(env_id, t.info['episode_info'])
eval_monitor.update_reward(env_id, reward)
self._logger.info(
"[EVALUATOR]env {} finish episode, final reward: {}, current episode: {}".format(
env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode()
)
)

# reset the finished env and init game_segments
if n_episode > self._env_num:
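
Based on the keys consumed in the collector and evaluator, the CrowdSim environment is assumed to report an info dict like the following at episode end. Only the key names are taken from this PR; the values and inline comments are placeholders.

# Assumed structure of `timestep.info` at episode end for CrowdSim (values are placeholders).
timestep_info = {
    'eval_episode_return': 0.0,
    'performance_info': {
        'mean_aoi': 0.0,                 # average age of information over the episode
        'mean_energy_consumption': 0.0,  # average energy consumed by the agent(s)
        'collected_data_amount': 0.0,    # total amount of data collected
        'human_coverage': 0.0,           # human coverage as reported by the environment
    },
}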
Empty file added zoo/CrowdSim/__init__.py