from torch.nn import functional as F
from copy import deepcopy

-from policy.networks import ActorCritic, Actor, Critic
+from policy.networks import ActorCritic, Actor, Critic, SACActor, SACCritic, SACValue
from policy.utils import ReplayBuffer, OUActionNoise, clip_action, GaussianActionNoise
-
+torch.autograd.set_detect_anomaly(True)

class BlackJackAgent:
    def __init__(self, method, env, function='V', gamma=0.99, epsilon=0.1):
@@ -332,7 +332,7 @@ def choose_action(self, observation, test):
        with torch.no_grad():
            action = self.actor(observation)
        if not test:
-            action = action + self.noise(action.size())
+            action = action + self.noise(action.size()).to(self.device)
        self.actor.train()
        action = action.cpu().detach().numpy()
        # clip noised action to ensure not out of bounds
@@ -357,12 +357,14 @@ def update(self):

        # calculate targets & only update online critic network
        self.critic.optimizer.zero_grad()
+        self.critic2.optimizer.zero_grad()
        with torch.no_grad():
            # y <- r + gamma * min_(i=1,2) Q_(theta'_i)(s', a_tilde)
            target_actions = self.target_actor(next_states)
            target_actions += self.noise(
-                target_actions.size(), clip=self.action_clip, sigma=self.action_sigma)
-            target_actions = clip_action(target_actions, self.max_action)
+                target_actions.size(), clip=self.action_clip, sigma=self.action_sigma).to(self.device)
+            target_actions = clip_action(target_actions.cpu().numpy(), self.max_action)
+            target_actions = torch.from_numpy(target_actions).to(self.device)
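+            # the clipped noise plus re-clipping above is TD3-style target policy smoothing;
+            # the numpy round-trip assumes clip_action operates on numpy arrays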
            q_primes1 = self.target_critic(next_states, target_actions).squeeze()
            q_primes2 = self.target_critic2(next_states, target_actions).squeeze()
            q_primes = torch.min(q_primes1, q_primes2)
@@ -375,4 +377,117 @@ def update(self):
        critic_loss = critic_loss1 + critic_loss2
        critic_loss.backward()
        self.critic.optimizer.step()
+        self.critic2.optimizer.step()
        return self.actor_loss, critic_loss.item()
+
+
+class SACAgent:
+    def __init__(self, state_dim, action_dim, hidden_dims, max_action, gamma,
+                 tau, reward_scale, lr, batch_size, maxsize, checkpoint):
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.gamma = gamma
+        self.tau = tau
+        self.reward_scale = reward_scale
+        self.batch_size = batch_size
+
+        self.memory = ReplayBuffer(state_dim, action_dim, maxsize)
+        self.critic1 = SACCritic(*state_dim, *action_dim, hidden_dims, lr,
+                                 checkpoint, 'Critic')
+        self.critic2 = SACCritic(*state_dim, *action_dim, hidden_dims,
+                                 lr, checkpoint, 'Critic2')
+        self.actor = SACActor(*state_dim, *action_dim, hidden_dims, max_action,
+                              lr, checkpoint, 'Actor')
+        self.value = SACValue(*state_dim, hidden_dims,
+                              lr, checkpoint, 'Valuator')
+        self.target_value = self.get_target_network(self.value)
+        self.target_value.name = 'Target_Valuator'
+
+    def get_target_network(self, online_network, freeze_weights=True):
+        target_network = deepcopy(online_network)
+        if freeze_weights:
+            for param in target_network.parameters():
+                param.requires_grad = False
+        return target_network
+
+    def choose_action(self, observation, test):
+        self.actor.eval()
+        observation = torch.from_numpy(observation).to(self.device)
+        with torch.no_grad():
+            action, _ = self.actor(observation)
+        self.actor.train()
+        action = action.cpu().detach().numpy()
+        return action
+
+    def update(self):
+        experiences = self.memory.sample_transition(self.batch_size)
+        states, actions, rewards, next_states, dones = [data.to(self.device) for data in experiences]
+
+        ###### UPDATE VALUATOR ######
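+        # soft value target: V(s) is regressed toward min_i Q_i(s, a~) - log pi(a~|s), with a~
+        # sampled from the current policy; reparameterize=False since only the value net needs gradients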
+        self.value.optimizer.zero_grad()
+        with torch.no_grad():
+            policy_actions, log_probs = self.actor(states, reparameterize=False)
+            action_values1 = self.critic1(states, policy_actions).squeeze()
+            action_values2 = self.critic2(states, policy_actions).squeeze()
+            action_values = torch.min(action_values1, action_values2)
+            target = action_values - log_probs.squeeze()
+        values = self.value(states).squeeze()
+        value_loss = 0.5 * F.mse_loss(target, values)
+        value_loss.backward()
+        self.value.optimizer.step()
+
+        ###### UPDATE CRITIC ######
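+        # soft Q target: y = reward_scale * r + gamma * V_target(s') for non-terminal s';
+        # in this formulation the reward scale plays the role of the entropy temperature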
+        self.critic1.optimizer.zero_grad()
+        self.critic2.optimizer.zero_grad()
+        with torch.no_grad():
+            v_hat = self.target_value(next_states).squeeze() * (~dones)
+            targets = rewards * self.reward_scale + self.gamma * v_hat
+        qs1 = self.critic1(states, actions).squeeze()
+        qs2 = self.critic2(states, actions).squeeze()
+        critic_loss1 = 0.5 * F.mse_loss(targets, qs1)
+        critic_loss2 = 0.5 * F.mse_loss(targets, qs2)
+        critic_loss = critic_loss1 + critic_loss2
+        critic_loss.backward()
+        self.critic1.optimizer.step()
+        self.critic2.optimizer.step()
+
+        ###### UPDATE ACTOR ######
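+        # actor objective: minimize E[log pi(a|s) - min_i Q_i(s, a)] over freshly sampled actions;
+        # this assumes the actor's default forward pass returns reparameterized samples, so the
+        # gradient can flow through the action into the policy parameters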
+        self.actor.optimizer.zero_grad()
+        actions, log_probs = self.actor(states)
+        action_values1 = self.critic1(states, actions).squeeze()
+        action_values2 = self.critic2(states, actions).squeeze()
+        action_values = torch.min(action_values1, action_values2)
+        actor_loss = torch.mean(log_probs.squeeze() - action_values)
+        actor_loss.backward()
+        self.actor.optimizer.step()
+
+        ###### UPDATE TARGET VALUE ######
+        self.update_target_network(self.value, self.target_value)
+
+        return value_loss.item(), critic_loss.item(), actor_loss.item()
+
+    def update_target_network(self, src, tgt):
+        # Polyak averaging: the target keeps a fraction tau of its old weights,
+        # so a tau close to 1 makes the target track the online network slowly
+        for src_weight, tgt_weight in zip(src.parameters(), tgt.parameters()):
+            tgt_weight.data = tgt_weight.data * self.tau + src_weight.data * (1. - self.tau)
+
+    def store_transition(self, state, action, reward, next_state, done):
+        state = torch.tensor(state)
+        action = torch.tensor(action)
+        reward = torch.tensor(reward)
+        next_state = torch.tensor(next_state)
+        done = torch.tensor(done, dtype=torch.bool)
+        self.memory.store_transition(state, action, reward, next_state, done)
+
+    def save_models(self):
+        self.critic1.save_checkpoint()
+        self.critic2.save_checkpoint()
+        self.actor.save_checkpoint()
+        self.value.save_checkpoint()
+        self.target_value.save_checkpoint()
+
+    def load_models(self):
+        self.critic1.load_checkpoint()
+        self.critic2.load_checkpoint()
+        self.actor.load_checkpoint()
+        self.value.load_checkpoint()
+        self.target_value.load_checkpoint()
+