[feature] Add experimental PyTorch support #4335
Merged
Changes from all commits (132 commits)
017c3cb
Begin porting work
awjuliani d99fc74
Add ResNet and distributions
awjuliani c981a81
Merge remote-tracking branch 'origin/master' into develop-add-fire
awjuliani 6dfb8fa
Merge remote-tracking branch 'origin/master' into develop-add-fire
awjuliani a492a9f
Dynamically construct actor and critic
awjuliani 5e6f4ae
Initial optimizer port
awjuliani 7e46bc5
Refactoring policy and optimizer
awjuliani a3a1c0f
Resolving a few bugs
awjuliani 652b399
Share more code between tf and torch policies
awjuliani da49aaa
Slightly closer to running model
awjuliani 1ae28be
Training runs, but doesn’t actually work
awjuliani b68eb20
Fix a couple additional bugs
awjuliani 5e39d84
Add conditional sigma for distribution
awjuliani a0d6823
Fix normalization
c807190
Merge remote-tracking branch 'origin/develop-add-fire-debug' into dev…
awjuliani e2d7fee
Support discrete actions as well
awjuliani f5b28d3
Continuous and discrete now train
awjuliani 50f5cc1
Multi-discrete now working
awjuliani 8c10cd3
Visual observations now train as well
awjuliani 8445661
Merge remote-tracking branch 'origin/master' into develop-add-fire
awjuliani deb6e92
GRU in-progress and dynamic cnns
awjuliani 57486ab
Fix for memories
awjuliani f6d5df5
Remove unused arg
awjuliani 5521670
Combine actor and critic classes. Initial export.
awjuliani 9b9e783
Support tf and pytorch alongside one another
awjuliani e98def6
Prepare model for onnx export
awjuliani d6d69ad
Merge remote-tracking branch 'origin/master' into develop-add-fire
awjuliani 2c6daac
Use LSTM and fix a few merge errors
awjuliani 411d0c4
Merge remote-tracking branch 'origin/master' into develop-add-fire
awjuliani 8b36db0
Fix bug in probs calculation
awjuliani ff72b3e
Optimize np -> tensor operations
awjuliani ee9fbd1
Time action sample function
awjuliani 8f92145
Small performance improvement during inference
awjuliani 4eead36
Merge remote-tracking branch 'origin/master' into develop-add-fire
awjuliani b3d1201
Merge master
awjuliani 892f385
ONNX exporting
awjuliani d12c053
Fix some issues with pdf
awjuliani 742d322
Fix bug in pdf function
awjuliani 509a858
Fix ResNet
awjuliani 2a22e17
Remove double setting
awjuliani 3442de5
Fix for discrete actions (#4181)
aadaca9
Fix discrete actions and GridWorld
a303586
Remove print statement
b3ca0c9
Convert List[np.ndarray] to np.ndarray before using torch.as_tensor (…
088cbe9
Develop add fire exp framework (#4213)
vincentpierre da3a7f8
reformating experiment_torch.py
vincentpierre 5d5c4ea
Pytorch port of SAC (#4219)
4214ec8
Update add-fire to latest master, including Policy refactor (#4263)
38c3dd1
[refactor] Refactor normalizers and encoders (#4275)
13b78e7
fix onnx save path and output_name
254f83b
add Saver class (only TF working)
43e32f6
Merge branch 'develop-add-fire-checkpoint' of https://github.com/Unit…
7756a87
fix pytorch checkpointing. add tensors in Normalizer as parameter
dbf2daf
remove print
02f1916
move tf and add torch model serialization
d57b830
remove
b62a1cd
remove unused
8bb30b1
add sac checkpoint
cce8227
small improvements
2da4d88
small improvements
6e8ed26
remove print
76ef088
move checkpoint_path logic to saver
17bacbb
[refactor] Refactor Actor and Critic classes (#4287)
ea93224
fix onnx input
1ff782a
fix formatting and test
949aa1f
[bug-fix] Fix non-LSTM SeparateActorCritic (#4306)
08b810a
small improvements
560f937
small improvement
02e35fd
[bug-fix] Fix error with discrete probs (#4309)
9d0fad2
[tests] Add tests for core PyTorch files (#4292)
6f9bd88
Merge branch 'develop-add-fire' into develop-add-fire-checkpoint
19c9ff0
[feature] Fix TF tests, add --torch CLI option, allow run TF without …
749acff
Test fixes on add-fire (#4317)
vincentpierre d33ad07
fix tests
ace4394
Add components directory and init (#4320)
andrewcoh 4759d1f
[add-fire] Halve Gaussian entropy (#4319)
c2b0074
[add-fire] Add learning rate and beta/epsilon decay to PyTorch (#4318)
143876b
Added Reward Providers for Torch (#4280)
vincentpierre 7b2c2f9
Fix discrete export (#4322)
dongruoping 9430fb3
[add-fire] Fix CategoricalDistInstance test and replace `range` with …
7c3ff1d
Develop add fire layers (#4321)
vincentpierre f54bf42
Merge branch 'master' into develop-add-fire-mm
e1dce72
fixing typo
vincentpierre 9913e71
[add-fire] Merge post-0.19.0 master into add-fire (#4328)
d9e6198
Revert "[add-fire] Merge post-0.19.0 master into add-fire (#4328)" (#…
1bae38e
More comments and Made ResNetBlock (#4329)
vincentpierre ff667e7
Merge pull request #4331 from Unity-Technologies/develop-add-fire-mm2
680c823
Merge branch 'develop-add-fire' into develop-add-fire-checkpoint
b6bc80d
update saver interface and add tests
42f24b3
update
9874a35
Fixed the reporting of the discriminator loss (#4348)
vincentpierre a23669d
Fix ONNX import for continuous
e51db51
fix export input names
83e17bb
Behavioral Cloning Pytorch (#4293)
andrewcoh b706bfe
Merge branch 'develop-add-fire-checkpoint' of https://github.com/Unit…
6d19f58
fix export input name
5ce6272
[add-fire] Add LSTM to SAC, LSTM fixes and initializations (#4324)
9d95298
add comments
003f4a6
Merge branch 'develop-add-fire' into develop-add-fire-checkpoint
cb87d78
Merge branch 'master' into develop-add-fire-mm3
06b2106
fix bc tests
61f3aca
Merge branch 'develop-add-fire-mm3' into develop-add-fire-checkpoint
4d7d118
change brain_name to behavior_name
de0265e
Merge master and add Saver class for save/load checkpoints
dongruoping 291091a
reverting Project settings
vincentpierre d37960c
[add-fire] Fix masked mean for 2d tensors (#4364)
c3fae3a
Removing the experiment script from add fire (#4373)
vincentpierre 71e7b17
[add-fire] Add tests and fix issues with Policy (#4372)
f9273bb
Pytorch ghost trainer (#4370)
andrewcoh 23e8d72
add test_simple_rl tests to torch
andrewcoh 6635413
revert tests
andrewcoh 1d89489
Fix of the test for multi visual input
vincentpierre 48e77c6
Make reset block submodule
6e75dd1
fix export input_name
7660a90
[add-fire] Memory class abstraction (#4375)
4db512b
make visual input channel first for export
bd41761
Merge branch 'develop-add-fire' into develop-add-fire-export
47212e5
Don't use torch.split in LSTM
09c2dc3
Add fire to test_simple_rl.py (#4378)
andrewcoh b22f412
Merge branch 'develop-add-fire' of github.com:Unity-Technologies/ml-a…
269a4c8
reverting unity_to_external_pb2_grpc.py
vincentpierre 3d7b809
remove duplicate of curr documentation
andrewcoh 1940d96
Revert "remove duplicate of curr documentation"
andrewcoh 9406624
remove duplicated curriculum doc (#4386)
andrewcoh 0a8b5e0
Fixed discrete models
e6eb502
Always export one Action tensor (#4388)
6f46b30
[add-fire] Revert unneeded changes back to master (#4389)
435d226
add comment
1a15577
fix test
38c1007
Fix export
dongruoping ddcf078
add fire clean up docstrings in create policies (#4391)
andrewcoh e93c746
[add-fire] Update changelog (#4397)
@@ -0,0 +1,94 @@
from typing import Dict, Optional, Tuple, List
import torch
import numpy as np

from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.torch.components.bc.module import BCModule
from mlagents.trainers.torch.components.reward_providers import create_reward_provider

from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.optimizer import Optimizer
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.torch.utils import ModelUtils


class TorchOptimizer(Optimizer):  # pylint: disable=W0223
    def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
        super().__init__()
        self.policy = policy
        self.trainer_settings = trainer_settings
        self.update_dict: Dict[str, torch.Tensor] = {}
        self.value_heads: Dict[str, torch.Tensor] = {}
        self.memory_in: torch.Tensor = None
        self.memory_out: torch.Tensor = None
        self.m_size: int = 0
        self.global_step = torch.tensor(0)
        self.bc_module: Optional[BCModule] = None
        self.create_reward_signals(trainer_settings.reward_signals)
        if trainer_settings.behavioral_cloning is not None:
            self.bc_module = BCModule(
                self.policy,
                trainer_settings.behavioral_cloning,
                policy_learning_rate=trainer_settings.hyperparameters.learning_rate,
                default_batch_size=trainer_settings.hyperparameters.batch_size,
                default_num_epoch=3,
            )

    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
        pass

    def create_reward_signals(self, reward_signal_configs):
        """
        Create reward signals.
        :param reward_signal_configs: Reward signal config.
        """
        for reward_signal, settings in reward_signal_configs.items():
            # Name reward signals by string in case we have duplicates later
            self.reward_signals[reward_signal.value] = create_reward_provider(
                reward_signal, self.policy.behavior_spec, settings
            )

    def get_trajectory_value_estimates(
        self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]:
        vector_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])]
        if self.policy.use_vis_obs:
            visual_obs = []
            for idx, _ in enumerate(
                self.policy.actor_critic.network_body.visual_encoders
            ):
                visual_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
                visual_obs.append(visual_ob)
        else:
            visual_obs = []

        memory = torch.zeros([1, 1, self.policy.m_size])

        vec_vis_obs = SplitObservations.from_observations(next_obs)
        next_vec_obs = [
            ModelUtils.list_to_tensor(vec_vis_obs.vector_observations).unsqueeze(0)
        ]
        next_vis_obs = [
            ModelUtils.list_to_tensor(_vis_ob).unsqueeze(0)
            for _vis_ob in vec_vis_obs.visual_observations
        ]

        value_estimates, next_memory = self.policy.actor_critic.critic_pass(
            vector_obs, visual_obs, memory, sequence_length=batch.num_experiences
        )

        next_value_estimate, _ = self.policy.actor_critic.critic_pass(
            next_vec_obs, next_vis_obs, next_memory, sequence_length=1
        )

        for name, estimate in value_estimates.items():
            value_estimates[name] = estimate.detach().cpu().numpy()
            next_value_estimate[name] = next_value_estimate[name].detach().cpu().numpy()

        if done:
            for k in next_value_estimate:
                if not self.reward_signals[k].ignore_done:
                    next_value_estimate[k] = 0.0

        return value_estimates, next_value_estimate
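The tail of `get_trajectory_value_estimates` zeroes the bootstrap value when the trajectory ended, except for reward signals whose provider sets `ignore_done`. A minimal torch-free sketch of that masking rule follows; the `DummySignal` class and the sample signal names and values are hypothetical stand-ins for illustration only, not part of the PR:

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class DummySignal:
    # Hypothetical stand-in for a reward provider; only the
    # ignore_done flag matters for this sketch.
    ignore_done: bool


def mask_next_values(
    next_value_estimate: Dict[str, float],
    reward_signals: Dict[str, DummySignal],
    done: bool,
) -> Dict[str, float]:
    # Mirrors the loop at the end of get_trajectory_value_estimates:
    # when the episode is done, zero the bootstrap value unless the
    # reward signal opts out via ignore_done.
    if done:
        for k in next_value_estimate:
            if not reward_signals[k].ignore_done:
                next_value_estimate[k] = 0.0
    return next_value_estimate


signals = {
    "extrinsic": DummySignal(ignore_done=False),
    "curiosity": DummySignal(ignore_done=True),
}
masked = mask_next_values({"extrinsic": 1.5, "curiosity": 0.7}, signals, done=True)
print(masked)  # extrinsic zeroed at episode end, curiosity bootstraps through
```

Zeroing the terminal bootstrap keeps value targets consistent for episodic rewards, while signals that ignore termination (such as some intrinsic rewards) continue to bootstrap.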
I've tried to remove this, this delta isn't picked up by git 👿