Skip to content

Commit

Permalink
code refactoring and modified models for td3 & vae
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrewzh112 committed Dec 28, 2020
1 parent f7ae0fb commit f24e842
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 84 deletions.
6 changes: 5 additions & 1 deletion sagan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ def __init__(self, img_channels, h_dim, img_size, num_classes=0):
self.disc = nn.Sequential(
ConvNormAct(img_channels, h_dim, 'sn', 'down', activation='lrelu', normalization='bn'),
SA_Conv2d(h_dim),
nn.BatchNorm2d(h_dim),
nn.LeakyReLU(0.2),
ConvNormAct(h_dim, h_dim*2, 'sn', 'down', activation='lrelu', normalization='bn'),
ConvNormAct(h_dim*2, h_dim*4, 'sn', 'down', activation='lrelu', normalization='bn'),
ConvNormAct(h_dim*4, h_dim*8, 'sn', 'down', activation='lrelu', normalization='bn'),
Expand Down Expand Up @@ -46,6 +48,8 @@ def __init__(self, h_dim, z_dim, img_channels, img_size, num_classes=0):
ConditionalConvBNAct(h_dim*4, h_dim*2, 'sn', 'up', activation='relu', normalization='bn', num_classes=num_classes),
ConditionalConvBNAct(h_dim*2, h_dim, 'sn', 'up', activation='relu', normalization='bn', num_classes=num_classes),
SA_Conv2d(h_dim),
ConditionalNorm(num_classes, h_dim) if num_classes > 0 else nn.BatchNorm2d(h_dim),
nn.ReLU(),
SN_ConvTranspose2d(in_channels=h_dim, out_channels=img_channels, kernel_size=4,
stride=2, padding=1),
nn.Tanh()
Expand All @@ -61,7 +65,7 @@ def forward(self, x, y=None):
self.min_hw,
self.min_hw)
for layer in self.gen:
if isinstance(layer, ConditionalConvBNAct) and y is not None:
if isinstance(layer, (ConditionalNorm, ConditionalConvBNAct)) and y is not None:
x_hat = layer(x_hat, y)
else:
x_hat = layer(x_hat)
Expand Down
8 changes: 5 additions & 3 deletions td3/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,16 @@ def __init__(self, env, alpha, beta, hidden_dims, tau,

# networks & optimizers
if img_input:
self.actor = ImageActor(in_channels, n_actions, self.max_action, order, depth, multiplier, 'actor').to(self.device)
self.actor = ImageActor(in_channels, n_actions, hidden_dim, self.max_action, order, depth, multiplier, 'actor').to(self.device)
self.critic_1 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, 'critic_1').to(self.device)
self.critic_2 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, 'critic_2').to(self.device)

self.target_actor = ImageActor(in_channels, n_actions, self.max_action, order, depth, multiplier, 'target_actor').to(self.device)
self.target_actor = ImageActor(in_channels, n_actions, hidden_dim, self.max_action, order, depth, multiplier, 'target_actor').to(self.device)
self.target_critic_1 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, 'target_critic_1').to(self.device)
self.target_critic_2 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, 'target_critic_2').to(self.device)

print('actor')
print(self.actor)

# physics networks
else:
self.actor = Actor(state_space, hidden_dims, n_actions, self.max_action, 'actor').to(self.device)
Expand Down
99 changes: 50 additions & 49 deletions td3/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,20 @@
parser.add_argument('--img_input', action="store_true", default=False, help='Use image as states')
parser.add_argument('--in_channels', type=int, default=3, help='Number of image channels for image input')
parser.add_argument('--depth', type=int, default=3, help='Depth for CNN architecture for image input')
parser.add_argument('--multiplier', type=int, default=32, help='Channel multiplier for CNN architecture for image input')
parser.add_argument('--multiplier', type=int, default=16, help='Channel multiplier for CNN architecture for image input')
parser.add_argument('--order', type=int, default=3, help='Store past (order) of frames for image input')
parser.add_argument('--action_embed_dim', type=int, default=32, help='Embedding dimension for actions for image input')
parser.add_argument('--hidden_dim', type=int, default=512, help='List of hidden dims for embedding networks')
parser.add_argument('--action_embed_dim', type=int, default=256, help='Embedding dimension for actions for image input')
parser.add_argument('--hidden_dim', type=int, default=256, help='Hidden dims for embedding networks')
parser.add_argument('--crop_dim', type=int, default=32, help='Crop dim for image inputs')

# training hp params
parser.add_argument('--n_episodes', type=int, default=1000, help='Number of episodes')
parser.add_argument('--batch_size', type=int, default=512, help='Batch size')
parser.add_argument('--alpha', type=float, default=0.001, help='Learning rate actor')
parser.add_argument('--beta', type=float, default=0.001, help='Learning rate critic')
parser.add_argument('--alpha', type=float, default=3e-4, help='Learning rate actor')
parser.add_argument('--beta', type=float, default=3e-4, help='Learning rate critic')
parser.add_argument('--warmup', type=int, default=1000, help='Number of warmup steps')
parser.add_argument('--d', type=int, default=2, help='Skip iteration')
parser.add_argument('--max_size', type=int, default=1000000, help='Replay buffer size')
parser.add_argument('--max_size', type=int, default=100000, help='Replay buffer size')
parser.add_argument('--no_render', action="store_true", default=False, help='Whether to render')
parser.add_argument('--window_size', type=int, default=100, help='Score tracking moving average window size')

Expand All @@ -58,6 +58,7 @@
import pyvirtualdisplay
_display = pyvirtualdisplay.Display(visible=False, size=(1400, 900))
_ = _display.start()

# paths
Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
Path(args.log_dir).mkdir(parents=True, exist_ok=True)
Expand All @@ -78,46 +79,46 @@
score_history = deque([], maxlen=args.window_size)
episodes = tqdm(range(args.n_episodes))

for e in episodes:
# resetting
state = env.reset()
if args.img_input:
state_queue = deque(
[preprocess_img(state['pixels'], args.crop_dim) for _ in range(args.order)],
maxlen=args.order)
state = torch.cat(list(state_queue), 1).cpu().numpy()
done, score = False, 0

while not done:
action = agent.choose_action(state)
state_, reward, done, _ = env.step(action)
if isinstance(reward, np.ndarray):
reward = reward[0]
if args.img_input:
state_queue.append(preprocess_img(state_['pixels'], args.crop_dim))
state_ = torch.cat(list(state_queue), 1).cpu().numpy()
agent.remember(state, action, reward, state_, done)
agent.learn()

# reset, log & render
score += reward
state = state_
episodes.set_postfix({'Reward': reward})
if args.no_render:
continue
env.render()

# logging
score_history.append(score)
moving_avg = sum(score_history) / len(score_history)
agent.add_scalar('Average Score', moving_avg, global_step=e)

# save weights @ best score
if moving_avg > best_score:
best_score = moving_avg
agent.save_networks()

tqdm.write(f'Episode: {e + 1}/{args.n_episodes}, \
Episode Score: {score}, \
Average Score: {moving_avg}, \
Best Score: {best_score}')
# for e in episodes:
# # resetting
# state = env.reset()
# if args.img_input:
# state_queue = deque(
# [preprocess_img(state['pixels'], args.crop_dim) for _ in range(args.order)],
# maxlen=args.order)
# state = torch.cat(list(state_queue), 1).cpu().numpy()
# done, score = False, 0

# while not done:
# action = agent.choose_action(state)
# state_, reward, done, _ = env.step(action)
# if isinstance(reward, np.ndarray):
# reward = reward[0]
# if args.img_input:
# state_queue.append(preprocess_img(state_['pixels'], args.crop_dim))
# state_ = torch.cat(list(state_queue), 1).cpu().numpy()
# agent.remember(state, action, reward, state_, done)
# agent.learn()

# # reset, log & render
# score += reward
# state = state_
# episodes.set_postfix({'Reward': reward, 'Iteration': agent.time_step})
# if args.no_render:
# continue
# env.render()

# # logging
# score_history.append(score)
# moving_avg = sum(score_history) / len(score_history)
# agent.add_scalar('Average Score', moving_avg, global_step=e)

# # save weights @ best score
# if moving_avg > best_score:
# best_score = moving_avg
# agent.save_networks()

# tqdm.write(f'Episode: {e + 1}/{args.n_episodes}, \
# Episode Score: {score}, \
# Average Score: {moving_avg}, \
# Best Score: {best_score}')
15 changes: 9 additions & 6 deletions td3/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def forward(self, state):


class ImageActor(nn.Module):
def __init__(self, in_channels, n_actions, max_action, order, depth, multiplier, name):
def __init__(self, in_channels, n_actions, hidden_dim, max_action, order, depth, multiplier, name):
super().__init__()
self.name = name
self.max_action = max_action
Expand All @@ -65,10 +65,13 @@ def __init__(self, in_channels, n_actions, max_action, order, depth, multiplier,
*convs,
nn.AdaptiveAvgPool2d(1),
nn.Flatten())
self.fc = nn.Linear(order*ch, n_actions)
self.fc = nn.Sequential(
nn.Linear(order * ch, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, n_actions))

def forward(self, imgs):
img_feature = [self.convs(imgs[:, i*self.order:(i+1)*self.order, :, :]) for i in range(self.order)]
img_feature = [self.convs(img) for img in imgs.chunk(self.order, 1)]
img_feature = torch.cat(img_feature, 1)
return torch.tanh(self.fc(img_feature)) * self.max_action

Expand Down Expand Up @@ -104,9 +107,9 @@ def __init__(self, in_channels, n_actions, hidden_dim, action_embed_dim, order,
nn.Linear(hidden_dim, 1)
)

def forward(self, state, action):
img_embedding = [self.avg_pool(self.convs(
state[:, i*self.order:(i+1)*self.order, :, :])).squeeze() for i in range(self.order)]
def forward(self, states, action):
img_embedding = [self.avg_pool(
self.convs(state)).squeeze() for state in states.chunk(self.order, 1)]
img_embedding = torch.cat(img_embedding, 1)
action_embedding = self.action_head(action)
combined_embedding = torch.cat([img_embedding, action_embedding], dim=1)
Expand Down
20 changes: 13 additions & 7 deletions vae/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@


class VAELoss(nn.Module):
def __init__(self, recon_type, beta):
def __init__(self, recon_type, beta, reduction='sum'):
super().__init__()
self.recon_type = recon_type
if recon_type == 'l2':
self.recon = nn.MSELoss(reduction='sum')
self.recon = nn.MSELoss(reduction=reduction)
elif recon_type == 'l1':
self.recon = nn.L1Loss(reduction='sum')
self.recon = nn.L1Loss(reduction=reduction)
elif recon_type == 'bce':
self.recon = nn.BCELoss(reduction='sum')
self.recon = nn.BCELoss(reduction=reduction)
else:
raise NotImplementedError
self.beta = beta
Expand All @@ -29,15 +29,16 @@ def forward(self, x, x_hat, mu, logvar):


class GroupSparsityLoss(nn.Module):
def __init__(self, n_elements):
def __init__(self, n_elements, rho=0.05):
super().__init__()
self.n_elements = n_elements
self.rho = torch.tensor(rho)

def forward(self, z):
rho, n_elements = self.rho, self.n_elements
# flatten latent variables & get dims
z = z.view(z.size(0), -1)
batch_size, z_dim = z.size()
n_elements = self.n_elements
assert z_dim % n_elements == 0
groups = z_dim // n_elements

Expand All @@ -47,4 +48,9 @@ def forward(self, z):
).unsqueeze(0).repeat(batch_size, 1, 1).to(z.device)
z_groups = z.unsqueeze(1).repeat(1, groups, 1)
masked_z = z_groups * mask
return masked_z.norm(p=2, dim=-1).sum(-1).mean()
rho_hat = masked_z.norm(p=2, dim=-1)

sparsity_penalty = (
(rho * (torch.log(rho) - torch.log(rho_hat))) + (
(1 - rho) * (torch.log(1 - rho) - torch.log(1 - rho_hat)))).sum(-1)
return sparsity_penalty.mean()
17 changes: 8 additions & 9 deletions vae/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,16 @@
parser.add_argument('--n_pelements', type=int, default=4096, help='elements per gorup for projected')

# loss fn
parser.add_argument('--beta', type=float, default=5., help='Beta hyperparam for KLD Loss')
parser.add_argument('--gamma', type=float, default=1e6, help='gamma hyperparam for Sparsity Loss')
parser.add_argument('--beta', type=float, default=0.1, help='Beta hyperparam for KLD Loss')
parser.add_argument('--gamma', type=float, default=1e-4, help='gamma hyperparam for Sparsity Loss')
parser.add_argument('--recon', type=str, default='bce', help='Reconstruction loss type [bce, l2]')
parser.add_argument('--reduction', type=str, default='mean', help='Loss reduction method [mean, sum]')

# training hyperparams
parser.add_argument('--device_ids', type=list, default=[0, 1], help='List of GPU devices')
parser.add_argument('--lr', type=float, default=0.0005, help='Learning rate for generators')
parser.add_argument('--betas', type=tuple, default=(0.5, 0.999), help='Betas for Adam optimizer')
parser.add_argument('--n_epochs', type=int, default=200, help='Number of epochs')
parser.add_argument('--n_epochs', type=int, default=50, help='Number of epochs')
parser.add_argument('--batch_size', type=int, default=256, help='Batch size')
parser.add_argument('--data_parallel', action="store_true", default=False, help='train with data parallel')

Expand Down Expand Up @@ -78,7 +79,7 @@
model.apply(initialize_modules)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=args.betas)
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lambda epoch: 0.995)
criterion = VAELoss(args.recon, args.beta)
criterion = VAELoss(args.recon, args.beta, args.reduction)

# fixed z to see how model changes on the same latent vectors
fixed_z = torch.randn(args.sample_size, args.z_dim).to(device)
Expand All @@ -103,11 +104,9 @@

# if there is sparsity loss
if args.plus:
std = logvar.mul(0.5).exp_()
sparsity_loss = GroupSparsityLoss(args.n_zelements)(mu) + \
GroupSparsityLoss(args.n_pelements)(z_p)
loss += args.gamma * sparsity_loss
s_loss.append(args.gamma * sparsity_loss.item())
sparsity_loss = args.gamma * GroupSparsityLoss(args.n_zelements)(mu)
loss += sparsity_loss
s_loss.append(sparsity_loss.item())

# logging and updating parameters
losses.append(loss.item())
Expand Down
22 changes: 13 additions & 9 deletions vae/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,14 @@ class VAE_Plus(VAE):
def __init__(self, z_dim, model_dim, img_size, img_channels):
super().__init__(z_dim, model_dim, img_size, img_channels)
self.encoder = nn.Sequential(
ConvNormAct(img_channels, model_dim, 'sn', 'down', activation='lrelu'),
ConvNormAct(img_channels, model_dim, 'sn', 'down', activation='relu'),
SA_Conv2d(model_dim),
ConvNormAct(model_dim, model_dim * 2, 'sn', 'down', activation='lrelu'),
ConvNormAct(model_dim * 2, model_dim * 4, 'sn', 'down', activation='lrelu'),
ConvNormAct(model_dim * 4, model_dim * 8, 'sn', None, activation='lrelu'),
ConvNormAct(model_dim * 8, model_dim * 8, 'sn', None, activation='lrelu'),
nn.BatchNorm2d(model_dim),
nn.ReLU(),
ConvNormAct(model_dim, model_dim * 2, 'sn', 'down', activation='relu'),
ConvNormAct(model_dim * 2, model_dim * 4, 'sn', 'down', activation='relu'),
ConvNormAct(model_dim * 4, model_dim * 8, 'sn', None, activation='relu'),
ConvNormAct(model_dim * 8, model_dim * 8, 'sn', None, activation='relu'),
nn.Flatten(),
nn.Linear((img_size // (2**3))**2 * model_dim * 8, z_dim * 2)
)
Expand All @@ -74,11 +76,13 @@ def __init__(self, z_dim, model_dim, img_size, img_channels):
nn.LeakyReLU(0.2)
)
self.decoder = nn.Sequential(
ConvNormAct(model_dim * 8, model_dim * 8, 'sn', None, activation='lrelu'),
ConvNormAct(model_dim * 8, model_dim * 4, 'sn', None, activation='lrelu'),
ConvNormAct(model_dim * 4, model_dim * 2, 'sn', 'up', activation='lrelu'),
ConvNormAct(model_dim * 2, model_dim, 'sn', 'up', activation='lrelu'),
ConvNormAct(model_dim * 8, model_dim * 8, 'sn', None, activation='relu'),
ConvNormAct(model_dim * 8, model_dim * 4, 'sn', None, activation='relu'),
ConvNormAct(model_dim * 4, model_dim * 2, 'sn', 'up', activation='relu'),
ConvNormAct(model_dim * 2, model_dim, 'sn', 'up', activation='relu'),
SA_Conv2d(model_dim),
nn.BatchNorm2d(model_dim),
nn.ReLU(),
SN_ConvTranspose2d(in_channels=model_dim, out_channels=img_channels, kernel_size=4, stride=2, padding=1),
nn.Tanh()
)
Expand Down

0 comments on commit f24e842

Please sign in to comment.