Skip to content

Commit f24e842

Browse files
committed
code refactoring and modified models for td3 & vae
1 parent f7ae0fb commit f24e842

File tree

7 files changed

+103
-84
lines changed

7 files changed

+103
-84
lines changed

sagan/model.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ def __init__(self, img_channels, h_dim, img_size, num_classes=0):
1212
self.disc = nn.Sequential(
1313
ConvNormAct(img_channels, h_dim, 'sn', 'down', activation='lrelu', normalization='bn'),
1414
SA_Conv2d(h_dim),
15+
nn.BatchNorm2d(h_dim),
16+
nn.LeakyReLU(0.2),
1517
ConvNormAct(h_dim, h_dim*2, 'sn', 'down', activation='lrelu', normalization='bn'),
1618
ConvNormAct(h_dim*2, h_dim*4, 'sn', 'down', activation='lrelu', normalization='bn'),
1719
ConvNormAct(h_dim*4, h_dim*8, 'sn', 'down', activation='lrelu', normalization='bn'),
@@ -46,6 +48,8 @@ def __init__(self, h_dim, z_dim, img_channels, img_size, num_classes=0):
4648
ConditionalConvBNAct(h_dim*4, h_dim*2, 'sn', 'up', activation='relu', normalization='bn', num_classes=num_classes),
4749
ConditionalConvBNAct(h_dim*2, h_dim, 'sn', 'up', activation='relu', normalization='bn', num_classes=num_classes),
4850
SA_Conv2d(h_dim),
51+
ConditionalNorm(num_classes, h_dim) if num_classes > 0 else nn.BatchNorm2d(h_dim),
52+
nn.ReLU(),
4953
SN_ConvTranspose2d(in_channels=h_dim, out_channels=img_channels, kernel_size=4,
5054
stride=2, padding=1),
5155
nn.Tanh()
@@ -61,7 +65,7 @@ def forward(self, x, y=None):
6165
self.min_hw,
6266
self.min_hw)
6367
for layer in self.gen:
64-
if isinstance(layer, ConditionalConvBNAct) and y is not None:
68+
if isinstance(layer, (ConditionalNorm, ConditionalConvBNAct)) and y is not None:
6569
x_hat = layer(x_hat, y)
6670
else:
6771
x_hat = layer(x_hat)

td3/agent.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,16 @@ def __init__(self, env, alpha, beta, hidden_dims, tau,
4949

5050
# networks & optimizers
5151
if img_input:
52-
self.actor = ImageActor(in_channels, n_actions, self.max_action, order, depth, multiplier, 'actor').to(self.device)
52+
self.actor = ImageActor(in_channels, n_actions, hidden_dim, self.max_action, order, depth, multiplier, 'actor').to(self.device)
5353
self.critic_1 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, 'critic_1').to(self.device)
5454
self.critic_2 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, 'critic_2').to(self.device)
5555

56-
self.target_actor = ImageActor(in_channels, n_actions, self.max_action, order, depth, multiplier, 'target_actor').to(self.device)
56+
self.target_actor = ImageActor(in_channels, n_actions, hidden_dim, self.max_action, order, depth, multiplier, 'target_actor').to(self.device)
5757
self.target_critic_1 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, 'target_critic_1').to(self.device)
5858
self.target_critic_2 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, 'target_critic_2').to(self.device)
59-
59+
print('actor')
60+
print(self.actor)
61+
6062
# physics networks
6163
else:
6264
self.actor = Actor(state_space, hidden_dims, n_actions, self.max_action, 'actor').to(self.device)

td3/main.py

Lines changed: 50 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,20 @@
2929
parser.add_argument('--img_input', action="store_true", default=False, help='Use image as states')
3030
parser.add_argument('--in_channels', type=int, default=3, help='Number of image channels for image input')
3131
parser.add_argument('--depth', type=int, default=3, help='Depth for CNN architecture for image input')
32-
parser.add_argument('--multiplier', type=int, default=32, help='Channel multiplier for CNN architecture for image input')
32+
parser.add_argument('--multiplier', type=int, default=16, help='Channel multiplier for CNN architecture for image input')
3333
parser.add_argument('--order', type=int, default=3, help='Store past (order) of frames for image input')
34-
parser.add_argument('--action_embed_dim', type=int, default=32, help='Embedding dimension for actions for image input')
35-
parser.add_argument('--hidden_dim', type=int, default=512, help='List of hidden dims for embedding networks')
34+
parser.add_argument('--action_embed_dim', type=int, default=256, help='Embedding dimension for actions for image input')
35+
parser.add_argument('--hidden_dim', type=int, default=256, help='Hidden dims for embedding networks')
3636
parser.add_argument('--crop_dim', type=int, default=32, help='Crop dim for image inputs')
3737

3838
# training hp params
3939
parser.add_argument('--n_episodes', type=int, default=1000, help='Number of episodes')
4040
parser.add_argument('--batch_size', type=int, default=512, help='Batch size')
41-
parser.add_argument('--alpha', type=float, default=0.001, help='Learning rate actor')
42-
parser.add_argument('--beta', type=float, default=0.001, help='Learning rate critic')
41+
parser.add_argument('--alpha', type=float, default=3e-4, help='Learning rate actor')
42+
parser.add_argument('--beta', type=float, default=3e-4, help='Learning rate critic')
4343
parser.add_argument('--warmup', type=int, default=1000, help='Number of warmup steps')
4444
parser.add_argument('--d', type=int, default=2, help='Skip iteration')
45-
parser.add_argument('--max_size', type=int, default=1000000, help='Replay buffer size')
45+
parser.add_argument('--max_size', type=int, default=100000, help='Replay buffer size')
4646
parser.add_argument('--no_render', action="store_true", default=False, help='Whether to render')
4747
parser.add_argument('--window_size', type=int, default=100, help='Score tracking moving average window size')
4848

@@ -58,6 +58,7 @@
5858
import pyvirtualdisplay
5959
_display = pyvirtualdisplay.Display(visible=False, size=(1400, 900))
6060
_ = _display.start()
61+
6162
# paths
6263
Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
6364
Path(args.log_dir).mkdir(parents=True, exist_ok=True)
@@ -78,46 +79,46 @@
7879
score_history = deque([], maxlen=args.window_size)
7980
episodes = tqdm(range(args.n_episodes))
8081

81-
for e in episodes:
82-
# resetting
83-
state = env.reset()
84-
if args.img_input:
85-
state_queue = deque(
86-
[preprocess_img(state['pixels'], args.crop_dim) for _ in range(args.order)],
87-
maxlen=args.order)
88-
state = torch.cat(list(state_queue), 1).cpu().numpy()
89-
done, score = False, 0
90-
91-
while not done:
92-
action = agent.choose_action(state)
93-
state_, reward, done, _ = env.step(action)
94-
if isinstance(reward, np.ndarray):
95-
reward = reward[0]
96-
if args.img_input:
97-
state_queue.append(preprocess_img(state_['pixels'], args.crop_dim))
98-
state_ = torch.cat(list(state_queue), 1).cpu().numpy()
99-
agent.remember(state, action, reward, state_, done)
100-
agent.learn()
101-
102-
# reset, log & render
103-
score += reward
104-
state = state_
105-
episodes.set_postfix({'Reward': reward})
106-
if args.no_render:
107-
continue
108-
env.render()
109-
110-
# logging
111-
score_history.append(score)
112-
moving_avg = sum(score_history) / len(score_history)
113-
agent.add_scalar('Average Score', moving_avg, global_step=e)
114-
115-
# save weights @ best score
116-
if moving_avg > best_score:
117-
best_score = moving_avg
118-
agent.save_networks()
119-
120-
tqdm.write(f'Episode: {e + 1}/{args.n_episodes}, \
121-
Episode Score: {score}, \
122-
Average Score: {moving_avg}, \
123-
Best Score: {best_score}')
82+
# for e in episodes:
83+
# # resetting
84+
# state = env.reset()
85+
# if args.img_input:
86+
# state_queue = deque(
87+
# [preprocess_img(state['pixels'], args.crop_dim) for _ in range(args.order)],
88+
# maxlen=args.order)
89+
# state = torch.cat(list(state_queue), 1).cpu().numpy()
90+
# done, score = False, 0
91+
92+
# while not done:
93+
# action = agent.choose_action(state)
94+
# state_, reward, done, _ = env.step(action)
95+
# if isinstance(reward, np.ndarray):
96+
# reward = reward[0]
97+
# if args.img_input:
98+
# state_queue.append(preprocess_img(state_['pixels'], args.crop_dim))
99+
# state_ = torch.cat(list(state_queue), 1).cpu().numpy()
100+
# agent.remember(state, action, reward, state_, done)
101+
# agent.learn()
102+
103+
# # reset, log & render
104+
# score += reward
105+
# state = state_
106+
# episodes.set_postfix({'Reward': reward, 'Iteration': agent.time_step})
107+
# if args.no_render:
108+
# continue
109+
# env.render()
110+
111+
# # logging
112+
# score_history.append(score)
113+
# moving_avg = sum(score_history) / len(score_history)
114+
# agent.add_scalar('Average Score', moving_avg, global_step=e)
115+
116+
# # save weights @ best score
117+
# if moving_avg > best_score:
118+
# best_score = moving_avg
119+
# agent.save_networks()
120+
121+
# tqdm.write(f'Episode: {e + 1}/{args.n_episodes}, \
122+
# Episode Score: {score}, \
123+
# Average Score: {moving_avg}, \
124+
# Best Score: {best_score}')

td3/networks.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def forward(self, state):
4646

4747

4848
class ImageActor(nn.Module):
49-
def __init__(self, in_channels, n_actions, max_action, order, depth, multiplier, name):
49+
def __init__(self, in_channels, n_actions, hidden_dim, max_action, order, depth, multiplier, name):
5050
super().__init__()
5151
self.name = name
5252
self.max_action = max_action
@@ -65,10 +65,13 @@ def __init__(self, in_channels, n_actions, max_action, order, depth, multiplier,
6565
*convs,
6666
nn.AdaptiveAvgPool2d(1),
6767
nn.Flatten())
68-
self.fc = nn.Linear(order*ch, n_actions)
68+
self.fc = nn.Sequential(
69+
nn.Linear(order * ch, hidden_dim),
70+
nn.ReLU(),
71+
nn.Linear(hidden_dim, n_actions))
6972

7073
def forward(self, imgs):
71-
img_feature = [self.convs(imgs[:, i*self.order:(i+1)*self.order, :, :]) for i in range(self.order)]
74+
img_feature = [self.convs(img) for img in imgs.chunk(self.order, 1)]
7275
img_feature = torch.cat(img_feature, 1)
7376
return torch.tanh(self.fc(img_feature)) * self.max_action
7477

@@ -104,9 +107,9 @@ def __init__(self, in_channels, n_actions, hidden_dim, action_embed_dim, order,
104107
nn.Linear(hidden_dim, 1)
105108
)
106109

107-
def forward(self, state, action):
108-
img_embedding = [self.avg_pool(self.convs(
109-
state[:, i*self.order:(i+1)*self.order, :, :])).squeeze() for i in range(self.order)]
110+
def forward(self, states, action):
111+
img_embedding = [self.avg_pool(
112+
self.convs(state)).squeeze() for state in states.chunk(self.order, 1)]
110113
img_embedding = torch.cat(img_embedding, 1)
111114
action_embedding = self.action_head(action)
112115
combined_embedding = torch.cat([img_embedding, action_embedding], dim=1)

vae/loss.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33

44

55
class VAELoss(nn.Module):
6-
def __init__(self, recon_type, beta):
6+
def __init__(self, recon_type, beta, reduction='sum'):
77
super().__init__()
88
self.recon_type = recon_type
99
if recon_type == 'l2':
10-
self.recon = nn.MSELoss(reduction='sum')
10+
self.recon = nn.MSELoss(reduction=reduction)
1111
elif recon_type == 'l1':
12-
self.recon = nn.L1Loss(reduction='sum')
12+
self.recon = nn.L1Loss(reduction=reduction)
1313
elif recon_type == 'bce':
14-
self.recon = nn.BCELoss(reduction='sum')
14+
self.recon = nn.BCELoss(reduction=reduction)
1515
else:
1616
raise NotImplementedError
1717
self.beta = beta
@@ -29,15 +29,16 @@ def forward(self, x, x_hat, mu, logvar):
2929

3030

3131
class GroupSparsityLoss(nn.Module):
32-
def __init__(self, n_elements):
32+
def __init__(self, n_elements, rho=0.05):
3333
super().__init__()
3434
self.n_elements = n_elements
35+
self.rho = torch.tensor(rho)
3536

3637
def forward(self, z):
38+
rho, n_elements = self.rho, self.n_elements
3739
# flatten latent variables & get dims
3840
z = z.view(z.size(0), -1)
3941
batch_size, z_dim = z.size()
40-
n_elements = self.n_elements
4142
assert z_dim % n_elements == 0
4243
groups = z_dim // n_elements
4344

@@ -47,4 +48,9 @@ def forward(self, z):
4748
).unsqueeze(0).repeat(batch_size, 1, 1).to(z.device)
4849
z_groups = z.unsqueeze(1).repeat(1, groups, 1)
4950
masked_z = z_groups * mask
50-
return masked_z.norm(p=2, dim=-1).sum(-1).mean()
51+
rho_hat = masked_z.norm(p=2, dim=-1)
52+
53+
sparsity_penalty = (
54+
(rho * (torch.log(rho) - torch.log(rho_hat))) + (
55+
(1 - rho) * (torch.log(1 - rho) - torch.log(1 - rho_hat)))).sum(-1)
56+
return sparsity_penalty.mean()

vae/main.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,16 @@
3030
parser.add_argument('--n_pelements', type=int, default=4096, help='elements per gorup for projected')
3131

3232
# loss fn
33-
parser.add_argument('--beta', type=float, default=5., help='Beta hyperparam for KLD Loss')
34-
parser.add_argument('--gamma', type=float, default=1e6, help='gamma hyperparam for Sparsity Loss')
33+
parser.add_argument('--beta', type=float, default=0.1, help='Beta hyperparam for KLD Loss')
34+
parser.add_argument('--gamma', type=float, default=1e-4, help='gamma hyperparam for Sparsity Loss')
3535
parser.add_argument('--recon', type=str, default='bce', help='Reconstruction loss type [bce, l2]')
36+
parser.add_argument('--reduction', type=str, default='mean', help='Loss reduction method [mean, sum]')
3637

3738
# training hyperparams
3839
parser.add_argument('--device_ids', type=list, default=[0, 1], help='List of GPU devices')
3940
parser.add_argument('--lr', type=float, default=0.0005, help='Learning rate for generators')
4041
parser.add_argument('--betas', type=tuple, default=(0.5, 0.999), help='Betas for Adam optimizer')
41-
parser.add_argument('--n_epochs', type=int, default=200, help='Number of epochs')
42+
parser.add_argument('--n_epochs', type=int, default=50, help='Number of epochs')
4243
parser.add_argument('--batch_size', type=int, default=256, help='Batch size')
4344
parser.add_argument('--data_parallel', action="store_true", default=False, help='train with data parallel')
4445

@@ -78,7 +79,7 @@
7879
model.apply(initialize_modules)
7980
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=args.betas)
8081
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lambda epoch: 0.995)
81-
criterion = VAELoss(args.recon, args.beta)
82+
criterion = VAELoss(args.recon, args.beta, args.reduction)
8283

8384
# fixed z to see how model changes on the same latent vectors
8485
fixed_z = torch.randn(args.sample_size, args.z_dim).to(device)
@@ -103,11 +104,9 @@
103104

104105
# if there is sparsity loss
105106
if args.plus:
106-
std = logvar.mul(0.5).exp_()
107-
sparsity_loss = GroupSparsityLoss(args.n_zelements)(mu) + \
108-
GroupSparsityLoss(args.n_pelements)(z_p)
109-
loss += args.gamma * sparsity_loss
110-
s_loss.append(args.gamma * sparsity_loss.item())
107+
sparsity_loss = args.gamma * GroupSparsityLoss(args.n_zelements)(mu)
108+
loss += sparsity_loss
109+
s_loss.append(sparsity_loss.item())
111110

112111
# logging and updating parameters
113112
losses.append(loss.item())

vae/model.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,14 @@ class VAE_Plus(VAE):
5959
def __init__(self, z_dim, model_dim, img_size, img_channels):
6060
super().__init__(z_dim, model_dim, img_size, img_channels)
6161
self.encoder = nn.Sequential(
62-
ConvNormAct(img_channels, model_dim, 'sn', 'down', activation='lrelu'),
62+
ConvNormAct(img_channels, model_dim, 'sn', 'down', activation='relu'),
6363
SA_Conv2d(model_dim),
64-
ConvNormAct(model_dim, model_dim * 2, 'sn', 'down', activation='lrelu'),
65-
ConvNormAct(model_dim * 2, model_dim * 4, 'sn', 'down', activation='lrelu'),
66-
ConvNormAct(model_dim * 4, model_dim * 8, 'sn', None, activation='lrelu'),
67-
ConvNormAct(model_dim * 8, model_dim * 8, 'sn', None, activation='lrelu'),
64+
nn.BatchNorm2d(model_dim),
65+
nn.ReLU(),
66+
ConvNormAct(model_dim, model_dim * 2, 'sn', 'down', activation='relu'),
67+
ConvNormAct(model_dim * 2, model_dim * 4, 'sn', 'down', activation='relu'),
68+
ConvNormAct(model_dim * 4, model_dim * 8, 'sn', None, activation='relu'),
69+
ConvNormAct(model_dim * 8, model_dim * 8, 'sn', None, activation='relu'),
6870
nn.Flatten(),
6971
nn.Linear((img_size // (2**3))**2 * model_dim * 8, z_dim * 2)
7072
)
@@ -74,11 +76,13 @@ def __init__(self, z_dim, model_dim, img_size, img_channels):
7476
nn.LeakyReLU(0.2)
7577
)
7678
self.decoder = nn.Sequential(
77-
ConvNormAct(model_dim * 8, model_dim * 8, 'sn', None, activation='lrelu'),
78-
ConvNormAct(model_dim * 8, model_dim * 4, 'sn', None, activation='lrelu'),
79-
ConvNormAct(model_dim * 4, model_dim * 2, 'sn', 'up', activation='lrelu'),
80-
ConvNormAct(model_dim * 2, model_dim, 'sn', 'up', activation='lrelu'),
79+
ConvNormAct(model_dim * 8, model_dim * 8, 'sn', None, activation='relu'),
80+
ConvNormAct(model_dim * 8, model_dim * 4, 'sn', None, activation='relu'),
81+
ConvNormAct(model_dim * 4, model_dim * 2, 'sn', 'up', activation='relu'),
82+
ConvNormAct(model_dim * 2, model_dim, 'sn', 'up', activation='relu'),
8183
SA_Conv2d(model_dim),
84+
nn.BatchNorm2d(model_dim),
85+
nn.ReLU(),
8286
SN_ConvTranspose2d(in_channels=model_dim, out_channels=img_channels, kernel_size=4, stride=2, padding=1),
8387
nn.Tanh()
8488
)

0 commit comments

Comments
 (0)