1
+ import torch
2
+ import torch .nn as nn
3
+ import torch .nn .functional as F
4
+ import numpy as np
5
+
6
+ from config import batch_size , gamma , quantile_embedding_dim , num_tau_sample , num_tau_prime_sample , num_quantile_sample
7
+
8
class QRDQN(nn.Module):
    """Distributional Q-network with implicit quantile embeddings.

    NOTE(review): despite the class name, the cosine embedding of sampled
    quantile fractions (`self.phi`) matches the IQN architecture (Dabney et
    al., 2018) rather than plain QR-DQN.
    """

    def __init__(self, num_inputs, num_outputs):
        """num_inputs: state feature size; num_outputs: number of actions."""
        super(QRDQN, self).__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs

        self.fc1 = nn.Linear(num_inputs, 128)
        self.fc2 = nn.Linear(128, num_outputs)
        # Embeds the cosine basis of the quantile fraction tau into the same
        # width as the state features so the two can be combined.
        self.phi = nn.Linear(quantile_embedding_dim, 128)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                # FIX: nn.init.xavier_uniform (no trailing underscore) is
                # deprecated and removed in recent PyTorch; use the in-place
                # variant.
                nn.init.xavier_uniform_(m.weight)

    def forward(self, state, tau, num_quantiles):
        """Return quantile values z, shape [input_size, num_outputs, num_quantiles].

        state: batch of states (batch is 1 in get_action, batch_size in
        training).  tau: sampled quantile fractions, expandable to
        (input_size * num_quantiles, quantile_embedding_dim).
        """
        input_size = state.size()[0]  # batch_size (train) or 1 (get_action)
        tau = tau.expand(input_size * num_quantiles, quantile_embedding_dim)
        # Cosine embedding basis: cos(pi * i * tau), i = 0 .. dim-1.
        pi_mtx = torch.Tensor(np.pi * np.arange(0, quantile_embedding_dim)).expand(
            input_size * num_quantiles, quantile_embedding_dim)
        cos_tau = torch.cos(tau * pi_mtx)

        phi = self.phi(cos_tau)
        phi = F.relu(phi)

        # Repeat each state once per sampled quantile.
        # NOTE(review): `expand` assumes `state` broadcasts to
        # (input_size, num_quantiles, num_inputs) — confirm callers pass
        # states shaped (input_size, 1, num_inputs) or input_size == 1.
        state_tile = state.expand(input_size, num_quantiles, self.num_inputs)
        state_tile = state_tile.flatten().view(-1, self.num_inputs)

        x = F.relu(self.fc1(state_tile))
        x = self.fc2(x * phi)  # multiplicative state/quantile interaction
        z = x.view(-1, num_quantiles, self.num_outputs)

        z = z.transpose(1, 2)  # [input_size, num_outputs, num_quantiles]
        return z

    def get_action(self, state):
        """Greedy action w.r.t. the CVaR-restricted quantile mean.

        Sampling tau uniformly from [0, 0.5) instead of [0, 1) averages only
        the lower half of the return distribution, giving a risk-averse
        (CVaR) policy.
        """
        tau = torch.Tensor(np.random.rand(num_quantile_sample, 1) * 0.5)  # CVaR
        z = self.forward(state, tau, num_quantile_sample)
        q = z.mean(dim=2, keepdim=True)  # mean over quantiles -> Q-values
        action = torch.argmax(q)
        return action.item()

    @classmethod
    def train_model(cls, online_net, target_net, optimizer, batch):
        """One optimizer step of the quantile-regression Huber loss.

        batch: named tuple of lists (state, next_state, action, reward,
        mask), where mask is 0 for terminal transitions.  Returns the
        scalar loss tensor.
        """
        states = torch.stack(batch.state)
        next_states = torch.stack(batch.next_state)
        actions = torch.Tensor(batch.action).long()
        rewards = torch.Tensor(batch.reward)
        masks = torch.Tensor(batch.mask)

        # Quantile values of the taken actions under the online net.
        tau = torch.Tensor(np.random.rand(batch_size * num_tau_sample, 1))
        z = online_net(states, tau, num_tau_sample)
        action = actions.unsqueeze(1).unsqueeze(1).expand(-1, 1, num_tau_sample)
        z_a = z.gather(1, action).squeeze(1)

        # FIX: the bootstrap target must not build an autograd graph through
        # the target network (the original accumulated useless gradients
        # into target_net).  no_grad leaves the loss value and the online
        # net's gradients unchanged.
        with torch.no_grad():
            tau_prime = torch.Tensor(
                np.random.rand(batch_size * num_tau_prime_sample, 1))
            next_z = target_net(next_states, tau_prime, num_tau_prime_sample)
            # Greedy next action under the target net's quantile mean.
            next_action = next_z.mean(dim=2).max(1)[1]
            next_action = next_action.unsqueeze(1).unsqueeze(1).expand(
                batch_size, 1, num_tau_prime_sample)
            next_z_a = next_z.gather(1, next_action).squeeze(1)

            # Distributional Bellman target: T_z = r + gamma * z' (masked
            # to zero the bootstrap term on terminal transitions).
            T_z = rewards.unsqueeze(1) + gamma * next_z_a * masks.unsqueeze(1)

        # Pair every target quantile with every online quantile.
        T_z_tile = T_z.view(-1, num_tau_prime_sample, 1).expand(
            -1, num_tau_prime_sample, num_tau_sample)
        z_a_tile = z_a.view(-1, 1, num_tau_sample).expand(
            -1, num_tau_prime_sample, num_tau_sample)

        error_loss = T_z_tile - z_a_tile
        huber_loss = nn.SmoothL1Loss(reduction='none')(T_z_tile, z_a_tile)
        # NOTE(review): this re-binds `tau` to a fixed uniform grid; the IQN
        # paper weights the loss by the *sampled* taus.  Kept as-is to
        # preserve behavior — confirm the deviation is intentional.
        tau = torch.arange(0, 1, 1 / num_tau_sample).view(1, num_tau_sample)

        # Quantile-regression Huber loss: |tau - 1{error < 0}| * huber,
        # averaged over tau', summed over tau, averaged over the batch.
        loss = (tau - (error_loss < 0).float()).abs() * huber_loss
        loss = loss.mean(dim=2).sum(dim=1).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss
0 commit comments