@@ -2,6 +2,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from config import gamma
+
 def set_init(layers):
     for layer in layers:
         nn.init.normal_(layer.weight, mean=0., std=0.1)
@@ -13,26 +15,27 @@ def __init__(self, num_inputs, num_outputs):
         self.num_inputs = num_inputs
         self.num_outputs = num_outputs
 
-        self.fc1 = nn.Linear(num_inputs, 128)
-        self.fc2 = nn.Linear(128, 128)
+        self.fc = nn.Linear(num_inputs, 128)
         self.fc_actor = nn.Linear(128, num_outputs)
-
-        self.fc3 = nn.Linear(num_inputs, 128)
-        self.fc4 = nn.Linear(128, 128)
         self.fc_critic = nn.Linear(128, 1)
 
-        set_init([self.fc1, self.fc2, self.fc_actor, self.fc3, self.fc4, self.fc_critic])
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight)
 
     def forward(self, input):
-        x = F.relu(self.fc1(input))
-        x = F.relu(self.fc2(x))
+        x = F.relu(self.fc(input))
         policy = F.softmax(self.fc_actor(x))
-
-        y = F.relu(self.fc3(input))
-        y = F.relu(self.fc4(y))
-        value = self.fc_critic(y)
+        value = self.fc_critic(x)
         return policy, value
 
+    def get_action(self, input):
+        policy, _ = self.forward(input)
+        policy = policy[0].data.numpy()
+
+        action = np.random.choice(self.num_outputs, 1, p=policy)[0]
+        return action
+
 
 class GlobalModel(Model):
     def __init__(self, num_inputs, num_outputs):
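With this change the actor and critic share a single 128-unit trunk (`self.fc`) instead of two separate towers, all `nn.Linear` weights get Xavier initialization, and workers sample actions through the new `get_action`. A minimal call sketch follows; the 4-input / 2-action sizes are invented for illustration, and `get_action` also relies on `numpy` being imported as `np` at the top of the file, outside this hunk.

import torch

model = Model(num_inputs=4, num_outputs=2)      # sizes are illustrative only
state = torch.Tensor([0.1, 0.0, -0.2, 0.05])    # one 1-D observation
action = model.get_action(state.unsqueeze(0))   # forward expects a (1, num_inputs) batch
print(action)                                   # an int sampled from the softmax policy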
@@ -43,34 +46,41 @@ class LocalModel(Model):
     def __init__(self, num_inputs, num_outputs):
         super(LocalModel, self).__init__(num_inputs, num_outputs)
 
-    def push_to_global_model(self, batch, global_model, global_optimizer, args):
+    def push_to_global_model(self, batch, global_model, global_optimizer):
         states = torch.stack(batch.state)
         next_states = torch.stack(batch.next_state)
-        actions = torch.Tensor(batch.action).long()
+        actions = torch.stack(batch.action)
         rewards = torch.Tensor(batch.reward)
         masks = torch.Tensor(batch.mask)
 
-        policy, value = self.forward(states[0])
+        policy, value = self.forward(states)
+        policy = policy.view(-1, self.num_outputs)
+        value = value.view(-1)
+
         _, last_value = self.forward(next_states[-1])
 
-        running_returns = last_value[0]
+        running_return = last_value[0].data
+        running_returns = torch.zeros(rewards.size())
         for t in reversed(range(0, len(rewards))):
-            running_returns = rewards[t] + args.gamma * running_returns * masks[t]
+            running_return = rewards[t] + gamma * running_return * masks[t]
+            running_returns[t] = running_return
 
-        pred = running_returns
-        td_error = pred - value[0]
 
-        log_policy = torch.log(policy[0] + 1e-5)[actions[0]]
-        loss1 = -log_policy * td_error.item()
-        loss2 = F.mse_loss(value[0], pred.detach())
-        entropy = torch.log(policy + 1e-5) * policy
-        loss = loss1 + loss2 - 0.01 * entropy.sum()
+        td_error = running_returns - value.detach()
+        log_policy = (torch.log(policy + 1e-10) * actions).sum(dim=1)
+        loss_policy = -log_policy * td_error
+        loss_value = torch.pow(running_returns - value, 2)
+        entropy = (torch.log(policy + 1e-10) * policy).sum(dim=1)
+
+        loss = (loss_policy + loss_value - 0.01 * entropy).mean()
 
         global_optimizer.zero_grad()
         loss.backward()
         for lp, gp in zip(self.parameters(), global_model.parameters()):
             gp._grad = lp.grad
         global_optimizer.step()
 
+        return loss
+
     def pull_from_global_model(self, global_model):
         self.load_state_dict(global_model.state_dict())
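For context, a sketch of how a worker process might drive these two methods end-to-end. Everything here is assumed for illustration and is not part of this commit: the gym-style env, the `Transition` namedtuple (its fields only need to match what `push_to_global_model` reads from `batch`), the one-hot action encoding that the new loss expects, the `run_worker` name, and a `global_optimizer` built over `global_model.parameters()` and shared across processes.

from collections import namedtuple
import torch

Transition = namedtuple('Transition', ('state', 'next_state', 'action', 'reward', 'mask'))

def run_worker(env, local_model, global_model, global_optimizer, num_episodes=1000):
    for _ in range(num_episodes):
        local_model.pull_from_global_model(global_model)         # start from the latest global weights
        memory = []
        state = torch.Tensor(env.reset())                        # 1-D observation of size num_inputs
        done = False

        while not done:
            action = local_model.get_action(state.unsqueeze(0))  # sample from the softmax policy
            next_state, reward, done, _ = env.step(action)
            next_state = torch.Tensor(next_state)

            one_hot = torch.zeros(local_model.num_outputs)       # the loss multiplies log(policy) by actions,
            one_hot[action] = 1.                                 # so actions are stored one-hot
            memory.append(Transition(state, next_state, one_hot, reward, 0 if done else 1))
            state = next_state

        batch = Transition(*zip(*memory))                        # field-wise tuples, as push_to_global_model expects
        loss = local_model.push_to_global_model(batch, global_model, global_optimizer)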