prototype jax with ddpg #187

Merged: 32 commits, Jul 12, 2022

Changes from 1 commit

Commits (32)
f127aa3  prototype jax with ddpg (vwxyzjn, May 29, 2022)
cbc5d88  Quick fix (vwxyzjn, Jun 22, 2022)
b4662c2  quick fix (vwxyzjn, Jun 22, 2022)
754a0b1  Commit changes - successful prototype (vwxyzjn, Jun 24, 2022)
223a8ff  Remove scripts (vwxyzjn, Jun 25, 2022)
85fbfe2  Simplify the implementation: careful with shape (vwxyzjn, Jun 25, 2022)
8ffbd26  Format (vwxyzjn, Jun 25, 2022)
c72cfb7  Remove code (vwxyzjn, Jun 25, 2022)
bfece78  formatting changes (vwxyzjn, Jun 25, 2022)
0710728  formatting change (vwxyzjn, Jun 25, 2022)
92d9d13  bug fix (vwxyzjn, Jun 25, 2022)
ee80f6b  correctly implementing keys (vwxyzjn, Jun 26, 2022)
0b30c57  these two lines are not necessary (vwxyzjn, Jun 28, 2022)
8e9f991  Adapting to the `TrainState` API (vwxyzjn, Jun 28, 2022)
38ca055  Simplify code (vwxyzjn, Jun 28, 2022)
3a58fcf  use `optax.incremental_update` (vwxyzjn, Jun 29, 2022); see the sketch after this list
207d09f  Also log q values (vwxyzjn, Jun 29, 2022)
6f4fa3d  Addresses #211 (vwxyzjn, Jun 29, 2022)
52243ec  Merge branch 'master' into jax-ddpg (vwxyzjn, Jun 29, 2022)
9ec4ac5  update docs (vwxyzjn, Jun 29, 2022)
acb3293  Add jax benchmark experiments (vwxyzjn, Jun 29, 2022)
0e9d8f4  remove old files (vwxyzjn, Jun 29, 2022)
8226824  update benchmark scripts (vwxyzjn, Jun 29, 2022)
57230c3  update lock files (vwxyzjn, Jun 29, 2022)
29a0aef  Handle action space bounds (vwxyzjn, Jun 30, 2022)
5f0ed84  Merge branch 'master' into jax-ddpg (vwxyzjn, Jun 30, 2022)
024b8c5  Add docs (vwxyzjn, Jun 30, 2022)
34c2825  Typo (vwxyzjn, Jun 30, 2022)
e12c283  update CI (vwxyzjn, Jun 30, 2022)
7b5febd  bug fix and add docs link (vwxyzjn, Jul 12, 2022)
eb85ae6  Add a note explaining the speed (vwxyzjn, Jul 12, 2022)
003a770  Update ddpg docs (vwxyzjn, Jul 12, 2022)
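
Commit 3a58fcf in the list above switches the target-network update to optax.incremental_update, which performs the soft (Polyak) update DDPG uses to track its target parameters. A minimal sketch of that pattern, with illustrative toy pytrees and a hypothetical tau value (none of these names come from the script itself):

    import jax.numpy as jnp
    import optax

    tau = 0.005  # soft-update coefficient (hypothetical value for illustration)
    online_params = {"w": jnp.ones((3,))}   # stand-in for the live critic params
    target_params = {"w": jnp.zeros((3,))}  # stand-in for the target critic params

    # Leaf-wise Polyak average: target <- tau * online + (1 - tau) * target
    target_params = optax.incremental_update(online_params, target_params, step_size=tau)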
Commit 207d09ff5f7badb9c5ea4e3d790fe4b7b84fca00: Also log q values
vwxyzjn committed Jun 29, 2022
13 changes: 7 additions & 6 deletions cleanrl/ddpg_continuous_action_jax.py
@@ -183,11 +183,12 @@ def update_critic(
         next_q_value = (rewards + (1 - dones) * args.gamma * (qf1_next_target)).reshape(-1)

         def mse_loss(params):
-            return ((qf1.apply(params, observations, actions).squeeze() - next_q_value) ** 2).mean()
+            qf1_a_values = qf1.apply(params, observations, actions).squeeze()
+            return ((qf1_a_values - next_q_value) ** 2).mean(), qf1_a_values.mean()

-        qf1_loss_value, grads = jax.value_and_grad(mse_loss)(qf1_state.params)
+        (qf1_loss_value, qf1_a_values), grads = jax.value_and_grad(mse_loss, has_aux=True)(qf1_state.params)
         qf1_state = qf1_state.apply_gradients(grads=grads)
-        return qf1_state, qf1_loss_value
+        return qf1_state, qf1_loss_value, qf1_a_values

     @jax.jit
     def update_actor(
@@ -247,7 +248,7 @@ def actor_loss(params):
     # ALGO LOGIC: training.
     if global_step > args.learning_starts:
         data = rb.sample(args.batch_size)
-        qf1_state, qf1_loss_value = update_critic(
+        qf1_state, qf1_loss_value, qf1_a_values = update_critic(
             actor_state,
             qf1_state,
             data.observations.numpy(),
@@ -257,7 +258,7 @@
             data.dones.flatten().numpy(),
         )
         if global_step % args.policy_frequency == 0:
-            (actor_state, qf1_state, actor_loss_value) = update_actor(
+            actor_state, qf1_state, actor_loss_value = update_actor(
                 actor_state,
                 qf1_state,
                 data.observations.numpy(),
@@ -266,7 +267,7 @@
         if global_step % 100 == 0:
             writer.add_scalar("losses/qf1_loss", qf1_loss_value.item(), global_step)
             writer.add_scalar("losses/actor_loss", actor_loss_value.item(), global_step)
-            # writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
+            writer.add_scalar("losses/qf1_values", qf1_a_values.item(), global_step)
             print("SPS:", int(global_step / (time.time() - start_time)))
             writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
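
The JAX idiom at the heart of this commit is has_aux=True: the loss function returns a (loss, aux) pair, jax.value_and_grad differentiates only the first element, and the auxiliary Q-value mean comes back alongside the gradients, so logging it costs no extra forward pass. A self-contained sketch of the pattern using a toy linear model (the model and variable names are illustrative, not taken from the CleanRL script):

    import jax
    import jax.numpy as jnp

    def loss_fn(params, x, y):
        pred = x @ params["w"] + params["b"]  # toy linear model
        loss = jnp.mean((pred - y) ** 2)      # MSE; the part that gets differentiated
        return loss, pred.mean()              # (loss, aux); aux is carried along for logging

    params = {"w": jnp.ones((3, 1)), "b": jnp.zeros((1,))}
    x, y = jnp.ones((8, 3)), jnp.zeros((8, 1))

    # has_aux=True tells value_and_grad the function returns (loss, aux_data)
    (loss, mean_pred), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x, y)

The same shape appears in update_critic above: (qf1_loss_value, qf1_a_values), grads = jax.value_and_grad(mse_loss, has_aux=True)(qf1_state.params).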