
Commit 167efd8
working on model
zarzouram committed Jun 22, 2022
1 parent c7cfb31 commit 167efd8
Showing 2 changed files with 4 additions and 34 deletions.
33 changes: 0 additions & 33 deletions codes/dataset/dataset_batcher.py
@@ -70,36 +70,3 @@ def __getitem__(self, i: int):

def collate_fn(self):
pass


# class collate_fn(object):

# def __init__(self, max_len, pad_id=0):
# self.max_len = max_len
# self.pad = pad_id

# def __call__(self, batch) -> Tuple[Tensor, Tensor, Tensor]:
# """
#         Pads a batch of variable-length sequences to a fixed length (max_len)
# """
# X, y, ls = zip(*batch)
# X: Tuple[Tensor]
# y: Tuple[Tensor]
# ls: Tuple[Tensor]

# # pad tuple
# # [B, max_seq_len, captns_num=5]
# ls = torch.stack(ls) # (B, num_captions)
# y = pad_sequence(y, batch_first=True, padding_value=self.pad)

# # pad to the max len
# pad_right = self.max_len - y.size(1)
# if pad_right > 0:
# # [B, captns_num, max_seq_len]
# y = y.permute(0, 2, 1) # type: Tensor
# y = ConstantPad1d((0, pad_right), value=self.pad)(y)
# y = y.permute(0, 2, 1) # [B, max_len, captns_num]

# X = torch.stack(X) # (B, 3, 256, 256)

# return X, y, ls
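
For reference, the commented-out batcher deleted above can be reconstructed as a runnable sketch. The code below is only an illustration based on the deleted comments, not the collate function this project actually ships: the class name CollateFn, the sample layout (image, captions, lengths), and the shape annotations are assumptions taken from the comments in the removed block.

from typing import Tuple

import torch
from torch import Tensor
from torch.nn import ConstantPad1d
from torch.nn.utils.rnn import pad_sequence


class CollateFn:
    """Pad a batch of variable-length caption tensors to a fixed max_len."""

    def __init__(self, max_len: int, pad_id: int = 0):
        self.max_len = max_len
        self.pad = pad_id

    def __call__(self, batch) -> Tuple[Tensor, Tensor, Tensor]:
        # Assumed sample layout: (image, captions, lengths), where image is
        # (3, 256, 256), captions is (seq_len, num_captions), and lengths
        # is (num_captions,).
        X, y, ls = zip(*batch)

        ls = torch.stack(ls)  # (B, num_captions)

        # Pad every caption tensor to the longest sequence in the batch.
        y = pad_sequence(y, batch_first=True, padding_value=self.pad)

        # Then pad the sequence dimension up to the fixed max_len.
        pad_right = self.max_len - y.size(1)
        if pad_right > 0:
            y = y.permute(0, 2, 1)  # (B, num_captions, seq_len)
            y = ConstantPad1d((0, pad_right), self.pad)(y)
            y = y.permute(0, 2, 1)  # (B, max_len, num_captions)

        X = torch.stack(X)  # (B, 3, 256, 256)
        return X, y, ls

A DataLoader would then use it via, for example, DataLoader(dataset, batch_size=B, collate_fn=CollateFn(max_len, pad_id)).
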
5 changes: 4 additions & 1 deletion run_train.py
@@ -285,7 +285,10 @@ def run_train(train,
hyperparameters = configs["model_hyperparameter"]
scheduler_parm = configs["scheduler_parm"]
model, optimizer, scheduler, var_scale, model_data = load_model(
hyperparameters, device, scheduler_parm, )
hyperparameters,
device,
scheduler_parm,
)

# load trainer class
train_param = configs["train_param"]
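
As a generic illustration of the pattern the call above follows (building a model, optimizer, and scheduler from config dictionaries), a minimal sketch is shown below. It is not the repository's load_model: the placeholder Linear model, the Adam optimizer, the StepLR scheduler, and the config keys in_dim, out_dim, and lr are all assumptions, and the real function also returns var_scale and model_data, which are not sketched here.

import torch
from torch.optim.lr_scheduler import StepLR


def build_training_objects(hyperparameters: dict, device: torch.device,
                           scheduler_parm: dict):
    # Placeholder model; the real project builds its own architecture from
    # the "model_hyperparameter" section of the config.
    model = torch.nn.Linear(hyperparameters["in_dim"],
                            hyperparameters["out_dim"]).to(device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=hyperparameters.get("lr", 1e-3))
    # scheduler_parm is assumed to hold StepLR keyword arguments,
    # e.g. {"step_size": 10, "gamma": 0.5}.
    scheduler = StepLR(optimizer, **scheduler_parm)
    return model, optimizer, scheduler
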
