
[refactor] Remove references to brain_name in policy #4134


Merged · 3 commits · Jun 18, 2020
Changes from all commits
2 changes: 1 addition & 1 deletion docs/Training-Configuration-File.md
@@ -154,7 +154,7 @@ A few considerations when deciding to use memory:
   too large `memory_size` will slow down training.
 - Adding a recurrent layer increases the complexity of the neural network, it is
   recommended to decrease `num_layers` when using recurrent.
-- It is required that `memory_size` be divisible by 4.
+- It is required that `memory_size` be divisible by 2.
 
 ## Self-Play
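For context on the relaxed constraint: `memory_size` is the length of the recurrent state vector carried between steps, and the even-size requirement comes from splitting that vector into two equal halves (the hidden and cell states of the LSTM). A minimal sketch of that assumed layout, not the actual ML-Agents implementation:

```python
import numpy as np

# Sketch of the assumed layout: the recurrent memory vector is split
# into two equal halves (LSTM hidden state and cell state), which is
# why memory_size must be divisible by 2.
memory_size = 128
memory = np.zeros(memory_size, dtype=np.float32)
hidden_state, cell_state = np.split(memory, 2)  # raises if memory_size is odd
assert hidden_state.shape == cell_state.shape == (memory_size // 2,)
```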
18 changes: 1 addition & 17 deletions ml-agents/mlagents/trainers/policy/tf_policy.py
@@ -95,18 +95,6 @@ def __init__(
         if self.network_settings.memory is not None:
             self.m_size = self.network_settings.memory.memory_size
             self.sequence_length = self.network_settings.memory.sequence_length
-            if self.m_size == 0:
-                raise UnityPolicyException(
-                    "The memory size for brain {0} is 0 even "
-                    "though the trainer uses recurrent.".format(brain.brain_name)
-                )
-            elif self.m_size % 2 != 0:
-                raise UnityPolicyException(
-                    "The memory size for brain {0} is {1} "
-                    "but it must be divisible by 2.".format(
-                        brain.brain_name, self.m_size
-                    )
-                )
         self._initialize_tensorflow_references()
         self.load = load

@@ -160,11 +148,7 @@ def _initialize_graph(self):
     def _load_graph(self, model_path: str, reset_global_steps: bool = False) -> None:
         with self.graph.as_default():
             self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
-            logger.info(
-                "Loading model for brain {} from {}.".format(
-                    self.brain.brain_name, model_path
-                )
-            )
+            logger.info(f"Loading model from {model_path}.")
             ckpt = tf.train.get_checkpoint_state(model_path)
             if ckpt is None:
                 raise UnityPolicyException(
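The removed checks are not simply deleted; as the settings.py diff below shows, they move into `NetworkSettings.MemorySettings`, so a bad memory configuration now fails when the settings object is constructed, before any policy exists. A quick sketch of the new failure mode, assuming the post-PR API and that `TrainerConfigError` lives in `mlagents.trainers.exception`:

```python
from mlagents.trainers.exception import TrainerConfigError
from mlagents.trainers.settings import NetworkSettings

# After this PR, invalid memory sizes are rejected at settings-construction
# time rather than inside TFPolicy.__init__.
try:
    NetworkSettings.MemorySettings(memory_size=63)
except TrainerConfigError as err:
    print(err)  # "... memory size must be divisible by 2."
```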
17 changes: 14 additions & 3 deletions ml-agents/mlagents/trainers/settings.py
@@ -51,10 +51,21 @@ def as_dict(self):
 
 @attr.s(auto_attribs=True)
 class NetworkSettings:
-    @attr.s(auto_attribs=True)
+    @attr.s
     class MemorySettings:
-        sequence_length: int = 64
-        memory_size: int = 128
+        sequence_length: int = attr.ib(default=64)
+        memory_size: int = attr.ib(default=128)
+
+        @memory_size.validator
+        def _check_valid_memory_size(self, attribute, value):
+            if value <= 0:
+                raise TrainerConfigError(
+                    "When using a recurrent network, memory size must be greater than 0."
+                )
+            elif value % 2 != 0:
+                raise TrainerConfigError(
+                    "When using a recurrent network, memory size must be divisible by 2."
+                )
 
     normalize: bool = False
     hidden_units: int = 128
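The switch from plain class-level defaults to `attr.ib(default=...)` is what makes the `@memory_size.validator` decorator possible: `attr.ib()` returns an object the validator method can be attached to, and attrs then runs that validator inside the generated `__init__`. A standalone sketch of the pattern, independent of ML-Agents:

```python
import attr

@attr.s
class Example:
    size: int = attr.ib(default=128)

    # attr.ib() returns an object whose .validator decorator registers
    # this method; attrs calls it from the generated __init__.
    @size.validator
    def _check_size(self, attribute, value):
        if value % 2 != 0:
            raise ValueError(f"{attribute.name} must be even, got {value}")

Example(size=10)  # constructs fine
Example(size=9)   # raises ValueError: size must be even, got 9
```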
9 changes: 9 additions & 0 deletions ml-agents/mlagents/trainers/tests/test_settings.py
@@ -6,6 +6,7 @@
 from mlagents.trainers.settings import (
     RunOptions,
     TrainerSettings,
+    NetworkSettings,
     PPOSettings,
     SACSettings,
     RewardSignalType,
@@ -155,6 +156,14 @@ def test_reward_signal_structure():
     )
 
 
+def test_memory_settings_validation():
+    with pytest.raises(TrainerConfigError):
+        NetworkSettings.MemorySettings(sequence_length=128, memory_size=63)
+
+    with pytest.raises(TrainerConfigError):
+        NetworkSettings.MemorySettings(sequence_length=128, memory_size=0)
+
+
 def test_parameter_randomization_structure():
     """
     Tests the ParameterRandomizationSettings structure method and all validators.
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/trainer/trainer.py
@@ -121,7 +121,7 @@ def export_model(self, name_behavior_id: str) -> None:
         Exports the model
         """
         policy = self.get_policy(name_behavior_id)
-        settings = SerializationSettings(policy.model_path, policy.brain.brain_name)
+        settings = SerializationSettings(policy.model_path, self.brain_name)
         export_policy_model(settings, policy.graph, policy.sess)
 
     @abc.abstractmethod

Review thread on the changed line:

Contributor: gross

Contributor Author: With PSankalp's change this stuff should go live in the policy anyways.
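This last change is the PR's goal in miniature: the behavior name now comes from the trainer itself rather than from a `brain` object hanging off the policy. A sketch of the dependency that gets removed, using simplified stand-in classes rather than the actual ML-Agents sources:

```python
from typing import NamedTuple

# Simplified stand-ins for illustration only; not the ML-Agents classes.
class SerializationSettings(NamedTuple):
    model_path: str
    brain_name: str

class Policy:
    def __init__(self, model_path: str):
        self.model_path = model_path  # no brain reference needed anymore

class Trainer:
    def __init__(self, brain_name: str):
        self.brain_name = brain_name  # the trainer owns the behavior name

    def export_model(self, policy: Policy) -> SerializationSettings:
        # Before this PR: policy.brain.brain_name; after: self.brain_name.
        return SerializationSettings(policy.model_path, self.brain_name)

print(Trainer("3DBall").export_model(Policy("results/3DBall")))
```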