 from torch.nn import functional as F
 from copy import deepcopy
 
-from policy.networks import ActorCritic, Actor, Critic
+from policy.networks import ActorCritic, Actor, Critic, SACActor, SACCritic, SACValue
 from policy.utils import ReplayBuffer, OUActionNoise, clip_action, GaussianActionNoise
-
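+# NOTE: anomaly detection adds checks to every backward pass and slows training
+# noticeably; it is a debugging aid for autograd errors and is normally removed afterwards.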
+torch.autograd.set_detect_anomaly(True)
 
 class BlackJackAgent:
     def __init__(self, method, env, function='V', gamma=0.99, epsilon=0.1):
@@ -332,7 +332,7 @@ def choose_action(self, observation, test):
         with torch.no_grad():
             action = self.actor(observation)
             if not test:
-                action = action + self.noise(action.size())
+                action = action + self.noise(action.size()).to(self.device)
         self.actor.train()
         action = action.cpu().detach().numpy()
         # clip noised action to ensure not out of bounds
@@ -357,12 +357,14 @@ def update(self):
 
         # calculate targets & only update online critic network
         self.critic.optimizer.zero_grad()
+        self.critic2.optimizer.zero_grad()
         with torch.no_grad():
            # y <- r + gamma * min_(i=1,2) Q_(theta'_i)(s', a_tilde)
             target_actions = self.target_actor(next_states)
             target_actions += self.noise(
-                target_actions.size(), clip=self.action_clip, sigma=self.action_sigma)
-            target_actions = clip_action(target_actions, self.max_action)
+                target_actions.size(), clip=self.action_clip, sigma=self.action_sigma).to(self.device)
+            target_actions = clip_action(target_actions.cpu().numpy(), self.max_action)
+            target_actions = torch.from_numpy(target_actions).to(self.device)
             q_primes1 = self.target_critic(next_states, target_actions).squeeze()
             q_primes2 = self.target_critic2(next_states, target_actions).squeeze()
             q_primes = torch.min(q_primes1, q_primes2)
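+            # the element-wise min of the two target critics is the clipped double-Q
+            # trick from TD3, which counteracts Q-value overestimation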
@@ -375,4 +377,117 @@ def update(self):
         critic_loss = critic_loss1 + critic_loss2
         critic_loss.backward()
         self.critic.optimizer.step()
+        self.critic2.optimizer.step()
         return self.actor_loss, critic_loss.item()
+
+
+class SACAgent:
+    def __init__(self, state_dim, action_dim, hidden_dims, max_action, gamma,
+                 tau, reward_scale, lr, batch_size, maxsize, checkpoint):
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.gamma = gamma
+        self.tau = tau
+        # reward scaling stands in for the entropy temperature in SAC v1:
+        # larger values weight reward more heavily relative to policy entropy
+        self.reward_scale = reward_scale
+        self.batch_size = batch_size
+
+        self.memory = ReplayBuffer(state_dim, action_dim, maxsize)
+        self.critic1 = SACCritic(*state_dim, *action_dim, hidden_dims, lr,
+                                 checkpoint, 'Critic')
+        self.critic2 = SACCritic(*state_dim, *action_dim, hidden_dims,
+                                 lr, checkpoint, 'Critic2')
+        self.actor = SACActor(*state_dim, *action_dim, hidden_dims, max_action,
+                              lr, checkpoint, 'Actor')
+        self.value = SACValue(*state_dim, hidden_dims,
+                              lr, checkpoint, 'Valuator')
+        self.target_value = self.get_target_network(self.value)
+        self.target_value.name = 'Target_Valuator'
+
+    def get_target_network(self, online_network, freeze_weights=True):
+        target_network = deepcopy(online_network)
+        if freeze_weights:
+            for param in target_network.parameters():
+                param.requires_grad = False
+        return target_network
+
+    def choose_action(self, observation, test):
+        self.actor.eval()
+        observation = torch.from_numpy(observation).to(self.device)
+        with torch.no_grad():
+            action, _ = self.actor(observation)
+        self.actor.train()
+        action = action.cpu().detach().numpy()
+        return action
+
+    def update(self):
+        experiences = self.memory.sample_transition(self.batch_size)
+        states, actions, rewards, next_states, dones = [data.to(self.device) for data in experiences]
+
+        ###### UPDATE VALUATOR ######
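+        # soft value target (SAC v1): V(s) ~= E_{a~pi}[ min_i Q_i(s, a) - log pi(a|s) ],
+        # with the entropy temperature fixed to 1; the trade-off is tuned via reward_scale instead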
+        self.value.optimizer.zero_grad()
+        with torch.no_grad():
+            policy_actions, log_probs = self.actor(states, reparameterize=False)
+            action_values1 = self.critic1(states, policy_actions).squeeze()
+            action_values2 = self.critic2(states, policy_actions).squeeze()
+            action_values = torch.min(action_values1, action_values2)
+            target = action_values - log_probs.squeeze()
+        values = self.value(states).squeeze()
+        value_loss = 0.5 * F.mse_loss(target, values)
+        value_loss.backward()
+        self.value.optimizer.step()
+
+        ###### UPDATE CRITIC ######
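+        # soft Bellman backup: y = reward_scale * r + gamma * V_target(s') * (1 - done)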
+        self.critic1.optimizer.zero_grad()
+        self.critic2.optimizer.zero_grad()
+        with torch.no_grad():
+            v_hat = self.target_value(next_states).squeeze() * (~dones)
+            targets = rewards * self.reward_scale + self.gamma * v_hat
+        qs1 = self.critic1(states, actions).squeeze()
+        qs2 = self.critic2(states, actions).squeeze()
+        critic_loss1 = 0.5 * F.mse_loss(targets, qs1)
+        critic_loss2 = 0.5 * F.mse_loss(targets, qs2)
+        critic_loss = critic_loss1 + critic_loss2
+        critic_loss.backward()
+        self.critic1.optimizer.step()
+        self.critic2.optimizer.step()
+
+        ###### UPDATE ACTOR ######
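+        # actor loss: E[ log pi(a|s) - min_i Q_i(s, a) ]; self.actor(states) is assumed to
+        # sample with the reparameterisation trick here so gradients flow through the action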
+        self.actor.optimizer.zero_grad()
+        actions, log_probs = self.actor(states)
+        action_values1 = self.critic1(states, actions).squeeze()
+        action_values2 = self.critic2(states, actions).squeeze()
+        action_values = torch.min(action_values1, action_values2)
+        actor_loss = torch.mean(log_probs.squeeze() - action_values)
+        actor_loss.backward()
+        self.actor.optimizer.step()
+
+        ###### UPDATE TARGET VALUE ######
+        self.update_target_network(self.value, self.target_value)
+
+        return value_loss.item(), critic_loss.item(), actor_loss.item()
+
+    def update_target_network(self, src, tgt):
+        # Polyak averaging: here `tau` is the fraction of the old target weights retained
+        # each update, so values close to 1 make the target track the online network slowly
+        for src_weight, tgt_weight in zip(src.parameters(), tgt.parameters()):
+            tgt_weight.data = tgt_weight.data * self.tau + src_weight.data * (1. - self.tau)
+
+    def store_transition(self, state, action, reward, next_state, done):
+        state = torch.tensor(state)
+        action = torch.tensor(action)
+        reward = torch.tensor(reward)
+        next_state = torch.tensor(next_state)
+        done = torch.tensor(done, dtype=torch.bool)
+        self.memory.store_transition(state, action, reward, next_state, done)
+
+    def save_models(self):
+        self.critic1.save_checkpoint()
+        self.critic2.save_checkpoint()
+        self.actor.save_checkpoint()
+        self.value.save_checkpoint()
+        self.target_value.save_checkpoint()
+
+    def load_models(self):
+        self.critic1.load_checkpoint()
+        self.critic2.load_checkpoint()
+        self.actor.load_checkpoint()
+        self.value.load_checkpoint()
+        self.target_value.load_checkpoint()
+
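+
+# Minimal usage sketch (illustrative only): assumes a Gym-style continuous-control
+# environment; the environment name and all hyperparameter values below are
+# hypothetical, not taken from this repository.
+#
+#   env = gym.make('Pendulum-v1')
+#   agent = SACAgent(env.observation_space.shape, env.action_space.shape,
+#                    hidden_dims=256, max_action=env.action_space.high,
+#                    gamma=0.99, tau=0.995, reward_scale=2.0, lr=3e-4,
+#                    batch_size=256, maxsize=1_000_000, checkpoint='checkpoints/')
+#   state, done = env.reset(), False
+#   while not done:
+#       action = agent.choose_action(state.astype(np.float32), test=False)
+#       next_state, reward, done, _ = env.step(action)
+#       agent.store_transition(state, action, reward, next_state, done)
+#       # in practice, wait until the buffer holds at least batch_size transitions
+#       value_loss, critic_loss, actor_loss = agent.update()
+#       state = next_state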