Add policies arg callback on_episode_step #18119

Merged on Aug 27, 2021 (38 commits)
Commits
47b5cb0  wip (ericl, May 10, 2020)
607db7c  lint (ericl, May 10, 2020)
d3b9c9e  add example (ericl, May 10, 2020)
e0196fa  Update (ericl, May 10, 2020)
6c95b5a  update (ericl, May 10, 2020)
8e40e05  update (ericl, May 10, 2020)
3de093c  update (ericl, May 10, 2020)
f53d9d3  torch (ericl, May 10, 2020)
bef3726  lens (ericl, May 11, 2020)
3f43b93  lint (ericl, May 11, 2020)
9590ab4  docs (ericl, May 11, 2020)
74f51e4  public api (ericl, May 11, 2020)
f8c5ae6  add fancy unbatch helpers (ericl, May 11, 2020)
867d73b  add comments (ericl, May 11, 2020)
a6e4110  get (ericl, May 11, 2020)
0ec7c90  rename to repeated (ericl, May 11, 2020)
a059ead  typo (ericl, May 11, 2020)
5763a8e  fix (ericl, May 11, 2020)
8611de7  update (ericl, May 11, 2020)
f79e4ec  Merge remote-tracking branch 'upstream/master' into list-type (ericl, May 18, 2020)
f6ac13a  fix bad offset in preprocessor (ericl, May 18, 2020)
4781e32  flake (ericl, May 19, 2020)
196c299  Merge 4781e32d1db5dc0927f7768eb27bbc8b0b14aec6 into a73c488c74b1e01da… (ericl, May 19, 2020)
d77cacc  Merge branch 'master' of https://github.com/ray-project/ray into HEAD (jsuarez5341, Jun 5, 2020)
0fcf6c1  1. Add utils.spaces.simplex.FlexDict; 2. Add agents.trainer.compute_a… (jsuarez5341, Jun 5, 2020)
5dfa0bc  Merge branch 'master' into flexdict (jsuarez5341, Jun 16, 2020)
c1d175c  Merge pull request #1 from jsuarez5341/flexdict (jsuarez5341, Jun 16, 2020)
b95c061  Refactor flexdict PR to work off of latest master (jsuarez5341, Jun 16, 2020)
4337ab9  style rllib/utils/spaces/space_utils.py (jsuarez5341, Jun 16, 2020)
c7d916f  Update #8792 to pass linter (jsuarez5341, Jun 17, 2020)
0ccd0dd  rewrite flexdict dict comprehension for linter (jsuarez5341, Jun 17, 2020)
c64b067  Include policies arg in on_episode_step callback signature (jsuarez5341, Aug 26, 2021)
4bc303f  Fix merge conflicts (jsuarez5341, Aug 26, 2021)
49e8e5f  Merge pull request #3 from jsuarez5341/ray-project-master (jsuarez5341, Aug 26, 2021)
029b38a  fix merge conflict (jsuarez5341, Aug 26, 2021)
761226d  Merge branch 'ray-project:master' into master (jsuarez5341, Aug 26, 2021)
b30586b  Merge branch 'ray-project:master' into master (jsuarez5341, Aug 26, 2021)
9ce6173  wip (sven1977, Aug 27, 2021)
13 changes: 10 additions & 3 deletions rllib/agents/callbacks.py
@@ -71,6 +71,7 @@ def on_episode_step(self,
                         *,
                         worker: "RolloutWorker",
                         base_env: BaseEnv,
+                        policies: Optional[Dict[PolicyID, Policy]] = None,
                         episode: MultiAgentEpisode,
                         env_index: Optional[int] = None,
                         **kwargs) -> None:
@@ -80,6 +81,9 @@ def on_episode_step(self,
             worker (RolloutWorker): Reference to the current rollout worker.
             base_env (BaseEnv): BaseEnv running the episode. The underlying
                 env object can be gotten by calling base_env.get_unwrapped().
+            policies (Optional[Dict[PolicyID, Policy]]): Mapping of policy id
+                to policy objects. In single agent mode there will only be a
+                single "default_policy".
             episode (MultiAgentEpisode): Episode object which contains episode
                 state. You can use the `episode.user_data` dict to store
                 temporary data, and `episode.custom_metrics` to store custom
@@ -109,8 +113,9 @@ def on_episode_end(self,
             worker (RolloutWorker): Reference to the current rollout worker.
             base_env (BaseEnv): BaseEnv running the episode. The underlying
                 env object can be gotten by calling base_env.get_unwrapped().
-            policies (dict): Mapping of policy id to policy objects. In single
-                agent mode there will only be a single "default" policy.
+            policies (Dict[PolicyID, Policy]): Mapping of policy id to policy
+                objects. In single agent mode there will only be a single
+                "default_policy".
             episode (MultiAgentEpisode): Episode object which contains episode
                 state. You can use the `episode.user_data` dict to store
                 temporary data, and `episode.custom_metrics` to store custom
@@ -144,7 +149,7 @@ def on_postprocess_trajectory(
             agent_id (str): Id of the current agent.
             policy_id (str): Id of the current policy for the agent.
             policies (dict): Mapping of policy id to policy objects. In single
-                agent mode there will only be a single "default" policy.
+                agent mode there will only be a single "default_policy".
             postprocessed_batch (SampleBatch): The postprocessed sample batch
                 for this agent. You can mutate this object to apply your own
                 trajectory postprocessing.
@@ -319,13 +324,15 @@ def on_episode_step(self,
                         *,
                         worker: "RolloutWorker",
                         base_env: BaseEnv,
+                        policies: Optional[Dict[PolicyID, Policy]] = None,
                         episode: MultiAgentEpisode,
                         env_index: Optional[int] = None,
                         **kwargs) -> None:
         for callback in self._callback_list:
             callback.on_episode_step(
                 worker=worker,
                 base_env=base_env,
+                policies=policies,
                 episode=episode,
                 env_index=env_index,
                 **kwargs)
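
Note on the callbacks.py change: a minimal sketch of why adding `policies` as an optional keyword-only argument should not break existing subclasses, and of how the fan-out in the last hunk forwards it. The classes below (`SketchCallbacks`, `LegacyCallbacks`, `PolicyAwareCallbacks`, `SketchMultiCallbacks`) are hypothetical stand-ins written for this illustration; they are not RLlib classes and only mirror the signatures shown in the diff.

```python
from typing import Any, Dict, List, Optional


class SketchCallbacks:
    """Hypothetical stand-in for the callbacks base class (not RLlib's)."""

    def on_episode_step(self, *, worker: Any, base_env: Any,
                        policies: Optional[Dict[str, Any]] = None,
                        episode: Any, env_index: Optional[int] = None,
                        **kwargs) -> None:
        pass


class LegacyCallbacks(SketchCallbacks):
    """Override written before `policies` existed in the signature."""

    def __init__(self) -> None:
        self.extra_keys: List[str] = []

    def on_episode_step(self, *, worker, base_env, episode,
                        env_index=None, **kwargs) -> None:
        # The new keyword is absorbed by **kwargs instead of raising.
        self.extra_keys = sorted(kwargs)


class PolicyAwareCallbacks(SketchCallbacks):
    """Override that declares the new `policies` argument."""

    def __init__(self) -> None:
        self.seen: List[str] = []

    def on_episode_step(self, *, worker, base_env, policies=None,
                        episode, env_index=None, **kwargs) -> None:
        # Default is None, so guard before iterating.
        self.seen = sorted(policies or {})


class SketchMultiCallbacks(SketchCallbacks):
    """Forwards every argument, including `policies`, to each callback."""

    def __init__(self, callbacks: List[SketchCallbacks]) -> None:
        self._callback_list = callbacks

    def on_episode_step(self, *, worker, base_env, policies=None,
                        episode, env_index=None, **kwargs) -> None:
        for callback in self._callback_list:
            callback.on_episode_step(
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode,
                env_index=env_index,
                **kwargs)


legacy = LegacyCallbacks()
aware = PolicyAwareCallbacks()
multi = SketchMultiCallbacks([legacy, aware])
multi.on_episode_step(worker=None, base_env=None,
                      policies={"default_policy": object()},
                      episode=None, env_index=0)
```

In this sketch the legacy override keeps working only because it accepts `**kwargs`; an override without `**kwargs` would raise a `TypeError` once the caller starts passing `policies`.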
1 change: 1 addition & 0 deletions rllib/evaluation/sampler.py
@@ -880,6 +880,7 @@ def _process_observations(
     callbacks.on_episode_step(
         worker=worker,
         base_env=base_env,
+        policies=worker.policy_map,
         episode=episode,
         env_index=env_id)
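
Note on the sampler.py change: the call site passes `worker.policy_map` directly, so the callback receives the worker's own mapping rather than a copy. A rough sketch of that behaviour follows; `RecordingCallbacks` and `step_once` are hypothetical names for this illustration, and the worker is a plain `SimpleNamespace`, not a real `RolloutWorker`.

```python
from types import SimpleNamespace


class RecordingCallbacks:
    """Hypothetical callback that records what it was called with."""

    def __init__(self):
        self.calls = []

    def on_episode_step(self, *, worker, base_env, policies=None,
                        episode, env_index=None, **kwargs):
        self.calls.append((env_index, policies))


def step_once(worker, base_env, episode, env_id, callbacks):
    """Hypothetical stand-in for the call site shown in the hunk above."""
    callbacks.on_episode_step(
        worker=worker,
        base_env=base_env,
        policies=worker.policy_map,
        episode=episode,
        env_index=env_id)


# Stand-in worker exposing only the attribute the call site reads.
worker = SimpleNamespace(policy_map={"default_policy": "policy-object"})
callbacks = RecordingCallbacks()
step_once(worker, base_env=None, episode=None, env_id=3, callbacks=callbacks)
```

Because the same object is handed through, anything a callback mutates in `policies` would also be visible through `worker.policy_map` in this sketch.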

1 change: 1 addition & 0 deletions rllib/examples/custom_metrics_and_callbacks.py
@@ -41,6 +41,7 @@ def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv,
         episode.hist_data["pole_angles"] = []

     def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv,
+                        policies: Dict[str, Policy],
                         episode: MultiAgentEpisode, env_index: int, **kwargs):
         # Make sure this episode is ongoing.
         assert episode.length > 0, \
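
Note on the example change: one possible use of the new argument inside a step callback, sketched under assumptions. `StepCallbacks` is a hypothetical class, the episode is a `SimpleNamespace` exposing only the `length`, `user_data` and `custom_metrics` attributes named in the diff, and the assertion message and metric names (`policy_ids`, `num_policies`) are made up for the illustration.

```python
from types import SimpleNamespace
from typing import Any, Dict


class StepCallbacks:
    """Hypothetical callback that reads `policies` on every step."""

    def on_episode_step(self, *, worker: Any, base_env: Any,
                        policies: Dict[str, Any],
                        episode: Any, env_index: int, **kwargs) -> None:
        # Same guard as the example file; the message text is an assumption.
        assert episode.length > 0, "on_episode_step() before the first step"
        # Remember which policy ids were present the first time we are called.
        episode.user_data.setdefault("policy_ids", sorted(policies))
        # Hypothetical metric name, chosen only for this sketch.
        episode.custom_metrics["num_policies"] = len(policies)


# Stand-in episode with just the attributes the callback touches.
episode = SimpleNamespace(length=1, user_data={}, custom_metrics={})
StepCallbacks().on_episode_step(
    worker=None,
    base_env=None,
    policies={"default_policy": object()},
    episode=episode,
    env_index=0)
```

In single agent mode the mapping would hold only the "default_policy" entry described in the docstring, which is what the stand-in call above passes.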