Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,17 @@

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [2.1.0] - 2026-05-10
### Changed
- Improving the robustness and learning capabilities of on-policy algorithms:
- Sample data from memory using per-epoch mini-batch shuffling
- Sum-reduce policy entropy to prevent collapse into near-deterministic stand still behavior
- Set the random memory `replacement` argument to false by default

### Fixed
- Fix time limits handling of truncation signals in on-policy agents/multi-agents
- Fix the indexing of finished episodes for cumulative rewards and timestep tracking

## [2.0.0] - 2026-04-08

Summary of the most relevant features:
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
if skrl.__version__ != "unknown":
release = version = skrl.__version__
else:
release = version = "2.0.0"
release = version = "2.1.0"

master_doc = "index"

Expand Down
8 changes: 4 additions & 4 deletions docs/source/snippets/multi_agents_basic_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
agent = IPPO(
possible_agents=env.possible_agents,
models=models,
memory=memories, # only required during training
memories=memories, # only required during training
cfg=cfg_agent,
observation_spaces=env.observation_spaces,
state_spaces=env.state_spaces,
Expand Down Expand Up @@ -48,7 +48,7 @@
agent = IPPO(
possible_agents=env.possible_agents,
models=models,
memory=memories, # only required during training
memories=memories, # only required during training
cfg=cfg_agent,
observation_spaces=env.observation_spaces,
state_spaces=env.state_spaces,
Expand Down Expand Up @@ -78,7 +78,7 @@
agent = MAPPO(
possible_agents=env.possible_agents,
models=models,
memory=memories, # only required during training
memories=memories, # only required during training
cfg=cfg_agent,
observation_spaces=env.observation_spaces,
state_spaces=env.state_spaces,
Expand Down Expand Up @@ -108,7 +108,7 @@
agent = MAPPO(
possible_agents=env.possible_agents,
models=models,
memory=memories, # only required during training
memories=memories, # only required during training
cfg=cfg_agent,
observation_spaces=env.observation_spaces,
state_spaces=env.state_spaces,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "skrl"
version = "2.0.0"
version = "2.1.0"
description = "Modular and flexible library for reinforcement learning on PyTorch and JAX"
readme = "README.md"
requires-python = ">=3.10"
Expand Down
10 changes: 5 additions & 5 deletions skrl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,8 @@ def __init__(self) -> None:
wp.init()

@staticmethod
def parse_device(device: str | "warp.context.Device" | None) -> "warp.context.Device":
"""Parse the input device and return a :py:class:`~warp.context.Device` instance.
def parse_device(device: str | "warp.Device" | None) -> "warp.Device":
"""Parse the input device and return a :py:class:`~warp.Device` instance.

:param device: Device specification. If the specified device is ``None`` or it cannot be resolved,
the default available device will be returned instead.
Expand All @@ -363,7 +363,7 @@ def parse_device(device: str | "warp.context.Device" | None) -> "warp.context.De
"""
import warp as wp

if isinstance(device, wp.context.Device):
if isinstance(device, wp.Device):
return device
elif isinstance(device, str):
try:
Expand All @@ -373,7 +373,7 @@ def parse_device(device: str | "warp.context.Device" | None) -> "warp.context.De
return wp.get_device()

@property
def device(self) -> "warp.context.Device":
def device(self) -> "warp.Device":
"""Default device.

The default device, unless specified, is ``cuda`` if CUDA is available, ``cpu`` otherwise.
Expand All @@ -382,7 +382,7 @@ def device(self) -> "warp.context.Device":
return self._device

@device.setter
def device(self, device: str | "warp.context.Device") -> None:
def device(self, device: str | "warp.Device") -> None:
self._device = device

@property
Expand Down
44 changes: 27 additions & 17 deletions skrl/agents/jax/a2c/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,27 @@


# https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function
@jax.jit
@functools.partial(jax.jit, static_argnames=("time_limit_bootstrap",))
def _compute_gae(
rewards: jax.Array,
terminated: jax.Array,
truncated: jax.Array,
values: jax.Array,
next_values: jax.Array,
last_values: jax.Array,
discount_factor: float = 0.99,
lambda_coefficient: float = 0.95,
time_limit_bootstrap: bool = False,
) -> jax.Array:
advantage = 0
advantages = jnp.zeros_like(rewards)
not_terminated = jnp.logical_not(terminated)
not_done = jnp.logical_not(jnp.logical_or(terminated, truncated) if time_limit_bootstrap else terminated)
memory_size = rewards.shape[0]

# advantages computation
for i in reversed(range(memory_size)):
next_values = values[i + 1] if i < memory_size - 1 else next_values
next_values = values[i + 1] if i < memory_size - 1 else last_values
advantage = (
rewards[i]
- values[i]
+ discount_factor * not_terminated[i] * (next_values + lambda_coefficient * advantage)
rewards[i] - values[i] + discount_factor * not_done[i] * (next_values + lambda_coefficient * advantage)
)
advantages = advantages.at[i].set(advantage)
# returns computation
Expand Down Expand Up @@ -226,6 +226,7 @@ def init(self, *, trainer_cfg: dict[str, Any] | None = None) -> None:
self.memory.create_tensor(name="actions", size=self.action_space, dtype=jnp.float32)
self.memory.create_tensor(name="rewards", size=1, dtype=jnp.float32)
self.memory.create_tensor(name="terminated", size=1, dtype=jnp.int8)
self.memory.create_tensor(name="truncated", size=1, dtype=jnp.int8)
self.memory.create_tensor(name="log_prob", size=1, dtype=jnp.float32)
self.memory.create_tensor(name="values", size=1, dtype=jnp.float32)
self.memory.create_tensor(name="returns", size=1, dtype=jnp.float32)
Expand Down Expand Up @@ -330,8 +331,15 @@ def record_transition(
rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps)

# time-limit (truncation) bootstrapping
if self.cfg.time_limit_bootstrap:
rewards += self.cfg.discount_factor * self._current_values * truncated
if self.cfg.time_limit_bootstrap and truncated.any():
inputs = {
"observations": self._observation_preprocessor(next_observations),
"states": self._state_preprocessor(next_states),
}
next_values, _ = self.value.act(inputs, role="value")
next_values = self._value_preprocessor(next_values, inverse=True)

rewards += self.cfg.discount_factor * next_values * truncated

# storage transition in memory
self.memory.add_samples(
Expand All @@ -340,6 +348,7 @@ def record_transition(
actions=actions,
rewards=rewards,
terminated=terminated,
truncated=truncated,
log_prob=self._current_log_prob,
values=self._current_values,
)
Expand Down Expand Up @@ -390,19 +399,18 @@ def update(self, *, timestep: int, timesteps: int) -> None:
returns, advantages = _compute_gae(
rewards=self.memory.get_tensor_by_name("rewards"),
terminated=self.memory.get_tensor_by_name("terminated"),
truncated=self.memory.get_tensor_by_name("truncated"),
values=values,
next_values=last_values,
last_values=last_values,
discount_factor=self.cfg.discount_factor,
lambda_coefficient=self.cfg.gae_lambda,
time_limit_bootstrap=self.cfg.time_limit_bootstrap,
)

self.memory.set_tensor_by_name("values", self._value_preprocessor(values, train=True))
self.memory.set_tensor_by_name("returns", self._value_preprocessor(returns, train=True))
self.memory.set_tensor_by_name("advantages", advantages)

# sample mini-batches from memory
sampled_batches = self.memory.sample_all(names=self._tensors_names, mini_batches=self.cfg.mini_batches)

cumulative_policy_loss = 0
cumulative_entropy_loss = 0
cumulative_value_loss = 0
Expand All @@ -417,7 +425,9 @@ def update(self, *, timestep: int, timesteps: int) -> None:
sampled_log_prob,
sampled_returns,
sampled_advantages,
) in sampled_batches:
) in self.memory.sample(
names=self._tensors_names, batch_size=len(self.memory), mini_batches=self.cfg.mini_batches
):

inputs = {
"observations": self._observation_preprocessor(sampled_observations, train=True),
Expand Down Expand Up @@ -482,11 +492,11 @@ def update(self, *, timestep: int, timesteps: int) -> None:
self.value_learning_rate *= self.value_scheduler(timestep)

# record data
self.track_data("Loss / Policy loss", cumulative_policy_loss / len(sampled_batches))
self.track_data("Loss / Value loss", cumulative_value_loss / len(sampled_batches))
self.track_data("Loss / Policy loss", cumulative_policy_loss / self.cfg.mini_batches)
self.track_data("Loss / Value loss", cumulative_value_loss / self.cfg.mini_batches)

if self.cfg.entropy_loss_scale:
self.track_data("Loss / Entropy loss", cumulative_entropy_loss / len(sampled_batches))
self.track_data("Loss / Entropy loss", cumulative_entropy_loss / self.cfg.mini_batches)

self.track_data("Policy / Standard deviation", stddev.mean().item())

Expand Down
4 changes: 2 additions & 2 deletions skrl/agents/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,8 @@ def record_transition(
if finished_episodes.size:

# storage cumulative rewards and timesteps
self._track_rewards.extend(self._cumulative_rewards[finished_episodes][:, 0].reshape(-1).tolist())
self._track_timesteps.extend(self._cumulative_timesteps[finished_episodes][:, 0].reshape(-1).tolist())
self._track_rewards.extend(self._cumulative_rewards[finished_episodes].tolist())
self._track_timesteps.extend(self._cumulative_timesteps[finished_episodes].tolist())

# reset the cumulative rewards and timesteps
self._cumulative_rewards[finished_episodes] = 0
Expand Down
38 changes: 24 additions & 14 deletions skrl/agents/jax/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,27 @@


# https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function
@jax.jit
@functools.partial(jax.jit, static_argnames=("time_limit_bootstrap",))
def _compute_gae(
rewards: jax.Array,
terminated: jax.Array,
truncated: jax.Array,
values: jax.Array,
next_values: jax.Array,
last_values: jax.Array,
discount_factor: float = 0.99,
lambda_coefficient: float = 0.95,
time_limit_bootstrap: bool = False,
) -> jax.Array:
advantage = 0
advantages = jnp.zeros_like(rewards)
not_terminated = jnp.logical_not(terminated)
not_done = jnp.logical_not(jnp.logical_or(terminated, truncated) if time_limit_bootstrap else terminated)
memory_size = rewards.shape[0]

# advantages computation
for i in reversed(range(memory_size)):
next_values = values[i + 1] if i < memory_size - 1 else next_values
next_values = values[i + 1] if i < memory_size - 1 else last_values
advantage = (
rewards[i]
- values[i]
+ discount_factor * not_terminated[i] * (next_values + lambda_coefficient * advantage)
rewards[i] - values[i] + discount_factor * not_done[i] * (next_values + lambda_coefficient * advantage)
)
advantages = advantages.at[i].set(advantage)
# returns computation
Expand Down Expand Up @@ -241,6 +241,7 @@ def init(self, *, trainer_cfg: dict[str, Any] | None = None) -> None:
self.memory.create_tensor(name="actions", size=self.action_space, dtype=jnp.float32)
self.memory.create_tensor(name="rewards", size=1, dtype=jnp.float32)
self.memory.create_tensor(name="terminated", size=1, dtype=jnp.int8)
self.memory.create_tensor(name="truncated", size=1, dtype=jnp.int8)
self.memory.create_tensor(name="log_prob", size=1, dtype=jnp.float32)
self.memory.create_tensor(name="values", size=1, dtype=jnp.float32)
self.memory.create_tensor(name="returns", size=1, dtype=jnp.float32)
Expand Down Expand Up @@ -345,8 +346,15 @@ def record_transition(
rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps)

# time-limit (truncation) bootstrapping
if self.cfg.time_limit_bootstrap:
rewards += self.cfg.discount_factor * self._current_values * truncated
if self.cfg.time_limit_bootstrap and truncated.any():
inputs = {
"observations": self._observation_preprocessor(next_observations),
"states": self._state_preprocessor(next_states),
}
next_values, _ = self.value.act(inputs, role="value")
next_values = self._value_preprocessor(next_values, inverse=True)

rewards += self.cfg.discount_factor * next_values * truncated

# storage transition in memory
self.memory.add_samples(
Expand All @@ -355,6 +363,7 @@ def record_transition(
actions=actions,
rewards=rewards,
terminated=terminated,
truncated=truncated,
log_prob=self._current_log_prob,
values=self._current_values,
)
Expand Down Expand Up @@ -405,19 +414,18 @@ def update(self, *, timestep: int, timesteps: int) -> None:
returns, advantages = _compute_gae(
rewards=self.memory.get_tensor_by_name("rewards"),
terminated=self.memory.get_tensor_by_name("terminated"),
truncated=self.memory.get_tensor_by_name("truncated"),
values=values,
next_values=last_values,
last_values=last_values,
discount_factor=self.cfg.discount_factor,
lambda_coefficient=self.cfg.gae_lambda,
time_limit_bootstrap=self.cfg.time_limit_bootstrap,
)

self.memory.set_tensor_by_name("values", self._value_preprocessor(values, train=True))
self.memory.set_tensor_by_name("returns", self._value_preprocessor(returns, train=True))
self.memory.set_tensor_by_name("advantages", advantages)

# sample mini-batches from memory
sampled_batches = self.memory.sample_all(names=self._tensors_names, mini_batches=self.cfg.mini_batches)

cumulative_policy_loss = 0
cumulative_entropy_loss = 0
cumulative_value_loss = 0
Expand All @@ -435,7 +443,9 @@ def update(self, *, timestep: int, timesteps: int) -> None:
sampled_values,
sampled_returns,
sampled_advantages,
) in sampled_batches:
) in self.memory.sample(
names=self._tensors_names, batch_size=len(self.memory), mini_batches=self.cfg.mini_batches
):

inputs = {
"observations": self._observation_preprocessor(sampled_observations, train=not epoch),
Expand Down
Loading
Loading