Commit 47c8c7a: working on model
zarzouram committed Jun 23, 2022 (1 parent: f366978)
Showing 4 changed files with 57 additions and 93 deletions.
3 changes: 2 additions & 1 deletion .gitignore
```diff
@@ -405,4 +405,5 @@ data/bert_data/based_uncased/config.json
 data/bert_data/based_uncased/.gitattributes
 data/models/*
 !data/models/slim_18-03-08h06_7485.pt
-.test
+.test
+.vscode
```
8 changes: 4 additions & 4 deletions codes/models/SLIM.py
```diff
@@ -4,9 +4,9 @@
 from torch import nn
 from torch import Tensor
 
-from models.transformer_encoder import CaptionEncoder
-from models.representation import RepresentationNetwork
-from models.generation import DRAW
+from codes.models.transformer_encoder import CaptionEncoder
+from codes.models.representation import RepresentationNetwork
+from codes.models.generation import DRAW
 
 
 class SLIM(nn.Module):
@@ -88,7 +88,7 @@ def init_weights(self):
             nn.init.constant_(m.weight, 1)
             nn.init.constant_(m.bias, 0)
 
-    def forward(self, batch: List[str, Tensor]) -> Tensor:
+    def forward(self, batch: List[Tensor]) -> Tensor:
 
         # Sizes:
         # ------
```
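The second hunk is a real bug fix, not a style change: `typing.List` is generic in exactly one type parameter, so the old annotation `List[str, Tensor]` raises a `TypeError` ("too many parameters") as soon as the annotation is evaluated, normally at import time. A minimal sketch of the distinction; the names here are illustrative, not from the repo:

```python
from typing import List, Tuple

import torch
from torch import Tensor

# What a mixed (caption, image) pair would look like, for contrast:
# Tuple takes several parameters, List takes exactly one.
CaptionImagePair = Tuple[str, Tensor]

def forward_stub(batch: List[Tensor]) -> Tensor:
    # Illustrative body only: stack a homogeneous mini-batch of tensors.
    return torch.stack(batch)

print(forward_stub([torch.zeros(3), torch.ones(3)]).shape)  # torch.Size([2, 3])
```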
2 changes: 1 addition & 1 deletion codes/models/generation.py
```diff
@@ -10,7 +10,7 @@
 from torch import nn
 from torch.distributions import Normal, kl_divergence
 
-from layers.conv_lstm_simple import ConvLSTMCell
+from codes.layers.conv_lstm_simple import ConvLSTMCell
 
 
 class DRAW(nn.Module):
```
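This import fix, like the SLIM.py one above, switches from paths relative to codes/ to absolute imports rooted at the repository top level. Those only resolve when the repo root is on sys.path, which holds when run_train.py is launched from the root. A sketch of the assumed layout, with a defensive fallback that is not in the repo:

```python
# Assumed layout (inferred from the import paths, not verified):
#   <repo_root>/
#       run_train.py
#       codes/
#           models/   SLIM.py, generation.py, ...
#           layers/   conv_lstm_simple.py
#
# Running `python run_train.py` from <repo_root> puts the root on
# sys.path, so `from codes.layers.conv_lstm_simple import ConvLSTMCell`
# resolves. A module executed from elsewhere can add the root itself:
import sys
from pathlib import Path

repo_root = Path(__file__).resolve().parents[2]  # codes/models/x.py -> repo root
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))
```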
137 changes: 50 additions & 87 deletions run_train.py
```diff
@@ -2,20 +2,16 @@
 import json
 from tqdm import tqdm
 
-from typing import List, Tuple
+from typing import List
 
 import random
 
 import torch
 from torch.utils.data import DataLoader
-from torch import optim
-import numpy as np
+# from torch import optim
+# import numpy as np
 
 from codes.dataset.dataset_batcher import SlimDataset
 from codes.models.SLIM import SLIM
 from codes.dataset.preprocessing import get_mini_batch
-from codes.helpers.train_helper import Trainer
-from codes.helpers.scheduler import LinearDecayLR, VarAnnealer
+# from codes.helpers.train_helper import Trainer
 
 from codes.utils.gpu_cuda_helper import select_device
 from codes.utils.utils import seed_everything
```
```diff
@@ -50,10 +46,11 @@ def parse_arguments():
                         default="draw",  #
                         help="pretraining a submodule, {draw, caption_encoder}")
 
-    parser.add_argument("--gpu",
-                        type=int,
-                        default=-1,
-                        help="GPU device to be used")
+    parser.add_argument(
+        '--device',
+        type=str,
+        default="gpu",  # gpu, cpu
+        help='Device to be used either gpu or cpu.')
 
     args = parser.parse_args()
 
```
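The argparse change replaces the integer --gpu index with a --device string, deferring the actual device choice to select_device (imported above from codes.utils.gpu_cuda_helper). That helper's body is not part of this diff, so the following is only a plausible sketch of what such a function does:

```python
import torch

def select_device_sketch(device: str = "gpu") -> torch.device:
    """Hypothetical stand-in for codes.utils.gpu_cuda_helper.select_device."""
    if device == "gpu" and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

# Usage mirroring the new flag, --device {gpu,cpu}:
device = select_device_sketch("gpu")
x = torch.zeros(1, device=device)  # lands on CPU when CUDA is absent
```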
```diff
@@ -67,33 +64,6 @@ def load_config_file(config_path: str) -> List[dict]:
     return configs
 
 
-# def load_model(model_parameters: dict,
-#                scheduler_param: dict,
-#                checkpoint_path: str = ""):
-
-#     lr_init = scheduler_param["lr_init"]
-#     model = SLIM(model_parameters)
-#     optimizer = optim.Adam(model.parameters(), lr=lr_init)
-#     scheduler = LinearDecayLR(optimizer, **scheduler_param)
-#     model_data = None
-#     # Variance scales
-#     var_scale = VarAnnealer(**configs["var_scale_parm"])
-
-#     if checkpoint_path != "":
-#         model_data = torch.load(checkpoint_path,
-#                                 map_location=torch.device("cpu"))
-#         model_state_dict = model_data["model_state_dict"]
-#         optimizer_state_dict = model_data["optimizer_state_dict"]
-#         scheduler_state_dict = model_data["scheduler_state_dict"]
-#         model.load_state_dict(model_state_dict)
-#         optimizer.load_state_dict(optimizer_state_dict)
-#         scheduler.load_state_dict(scheduler_state_dict)
-#         var_scale.scale = model_data["var_scale"]
-#         var_scale.t = model_data["steps"] + 1
-
-#     return model, optimizer, scheduler, var_scale, model_data
-
-
 def run_train(train,
               train_iter,
               val_iter,
```
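The deleted block was already commented out, but it documents the checkpoint format the script used: one torch.save file bundling model, optimizer, and scheduler state plus a few bookkeeping scalars. A condensed sketch of that pattern; the dictionary keys are copied from the removed code, while the save-side counterpart is implied rather than shown:

```python
import torch

def save_checkpoint(path, model, optimizer, scheduler, var_scale,
                    steps, loss, epoch):
    # Assumed counterpart of the removed loader; keys match what it read.
    torch.save({
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "var_scale": var_scale.scale,  # variance-annealer state
        "steps": steps,
        "loss": loss,
        "epoch": epoch,
    }, path)

def resume(path, model, optimizer, scheduler):
    # Mirrors the removed load_model: restore the three state dicts and
    # hand the remaining scalars back to the caller.
    data = torch.load(path, map_location=torch.device("cpu"))
    model.load_state_dict(data["model_state_dict"])
    optimizer.load_state_dict(data["optimizer_state_dict"])
    scheduler.load_state_dict(data["scheduler_state_dict"])
    return data  # caller restores var_scale.scale, steps + 1, best loss, epoch
```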
```diff
@@ -285,51 +255,44 @@ def run_train(train,
     hyperparameters = configs["model_hyperparameter"]
     model = SLIM(params=hyperparameters, pretrain=pretrain)
 
-    scheduler_parm = configs["scheduler_parm"]
-    model, optimizer, scheduler, var_scale, model_data = load_model(
-        hyperparameters,
-        device,
-        scheduler_parm,
-    )
-
-    # load trianer class
-    train_param = configs["train_param"]
-    # number of steps per epoch
-    epoch_intrv = int(train_param["samples_num"] /
-                      train_param["mini_batch_size"])
-    trainer = Trainer(model,
-                      device,
-                      epoch_interval=epoch_intrv,
-                      save_path=configs["checkpoints_dir"])
-    if args.checkpoint_model != "":  # resume from checkpoint
-        trainer.global_steps = model_data["steps"] + 1
-        trainer.best_loss = model_data["loss"]
-        trainer.epoch = model_data["epoch"]
-
-    # Init Visualization
-    env_name = args.plot_env_name
-    vis = Visualizations(env_name=env_name)
-    legend = [["Train", "Validation"]]
-    title = [f"Loss Plot (mean every 1 epoch/{epoch_intrv} steps)"]
-    xlabel = [f"Epoch ({epoch_intrv} steps)"]
-    ylabel = ["ELBO Loss"]
-    win_name = [f"{env_name}_total_Loss"]
-    if args.plot_loss_comp.lower() == "y":
-        for loss_type in ["Const_Loss", "KlD_Loss"]:
-            legend.append(["Train", "Validation"])
-            title.append(
-                f"{loss_type} Plot (mean every 1 epoch/{epoch_intrv} steps)")
-            xlabel.append(f"Epoch ({epoch_intrv} steps)")
-            ylabel.append(f"{loss_type}")
-            win_name.append(f"{env_name}_{loss_type}")
-    opt_win = {
-        "win_name": win_name,
-        "xlabel": xlabel,
-        "ylabel": ylabel,
-        "title": title,
-        "legend": legend
-    }
-    vis.add_wins(**opt_win)
-
-    run_train(trainer, train_iter, val_iter, model, optimizer, scheduler,
-              configs, var_scale, vis, win_name)
+    # # load trianer class
+    # train_param = configs["train_param"]
+    # # number of steps per epoch
+    # epoch_intrv = int(train_param["samples_num"] /
+    #                   train_param["mini_batch_size"])
+    # trainer = Trainer(model,
+    #                   device,
+    #                   epoch_interval=epoch_intrv,
+    #                   save_path=configs["checkpoints_dir"])
+    # if args.checkpoint_model != "":  # resume from checkpoint
+    #     trainer.global_steps = model_data["steps"] + 1
+    #     trainer.best_loss = model_data["loss"]
+    #     trainer.epoch = model_data["epoch"]
+
+    # # Init Visualization
+    # env_name = args.plot_env_name
+    # # vis = Visualizations(env_name=env_name)
+    # legend = [["Train", "Validation"]]
+    # title = [f"Loss Plot (mean every 1 epoch/{epoch_intrv} steps)"]
+    # xlabel = [f"Epoch ({epoch_intrv} steps)"]
+    # ylabel = ["ELBO Loss"]
+    # win_name = [f"{env_name}_total_Loss"]
+    # if args.plot_loss_comp.lower() == "y":
+    #     for loss_type in ["Const_Loss", "KlD_Loss"]:
+    #         legend.append(["Train", "Validation"])
+    #         title.append(
+    #             f"{loss_type} Plot (mean every 1 epoch/{epoch_intrv} steps)")
+    #         xlabel.append(f"Epoch ({epoch_intrv} steps)")
+    #         ylabel.append(f"{loss_type}")
+    #         win_name.append(f"{env_name}_{loss_type}")
+    # opt_win = {
+    #     "win_name": win_name,
+    #     "xlabel": xlabel,
+    #     "ylabel": ylabel,
+    #     "title": title,
+    #     "legend": legend
+    # }
+    # vis.add_wins(**opt_win)
+
+    # run_train(trainer, train_iter, val_iter, model, optimizer, scheduler,
+    #           configs, var_scale, vis, win_name)
```
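The now-disabled visualization block built one plot window per tracked loss: the total ELBO, plus Const_Loss and KlD_Loss when --plot_loss_comp is y, each with Train and Validation traces. The repo's Visualizations wrapper is not shown in this commit, so the sketch below drives the visdom package directly; the window names and labels follow the commented code, everything else is an assumption:

```python
import numpy as np
import visdom

env_name = "slim_plots"  # hypothetical; the script reads it from --plot_env_name
vis = visdom.Visdom(env=env_name)

# One line-plot window per loss component, seeded with a dummy point.
for loss in ["total_Loss", "Const_Loss", "KlD_Loss"]:
    vis.line(X=np.zeros(1),
             Y=np.zeros((1, 2)),  # two traces: Train, Validation
             win=f"{env_name}_{loss}",
             opts=dict(title=f"{loss} plot",
                       xlabel="Epoch",
                       ylabel="ELBO Loss" if loss == "total_Loss" else loss,
                       legend=["Train", "Validation"]))
```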
