Skip to content

Commit

Permalink
Update evals to not load retrieval embeddings.
Browse files Browse the repository at this point in the history
  • Loading branch information
kohjingyu committed Jul 17, 2023
1 parent 0bd94da commit c7de07a
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion evals/generate_visdial_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
os.makedirs(output_dir, exist_ok=True)
print('Saving to', output_dir)

model = models.load_gill('checkpoints/gill_opt/')
model = models.load_gill('checkpoints/gill_opt/', load_ret_embs=False)
g_cuda = torch.Generator(device='cuda').manual_seed(42) # Fix the random seed.

# Load VisDial data.
Expand Down
2 changes: 1 addition & 1 deletion evals/generate_vist_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
os.makedirs(output_dir, exist_ok=True)
print('Saving to', output_dir)

model = models.load_gill('checkpoints/gill_opt/')
model = models.load_gill('checkpoints/gill_opt/', load_ret_embs=False)
g_cuda = torch.Generator(device='cuda').manual_seed(42) # Fix the random seed.

# Load VIST data.
Expand Down
8 changes: 5 additions & 3 deletions gill/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@ def get_log_likelihood_scores(
return -outputs.loss.item()


def load_gill(model_dir: str) -> GILL:
def load_gill(model_dir: str, load_ret_embs: bool = True) -> GILL:
model_args_path = os.path.join(model_dir, 'model_args.json')
model_ckpt_path = os.path.join(model_dir, 'pretrained_ckpt.pth.tar')
embs_paths = [s for s in glob.glob(os.path.join(model_dir, 'cc3m*.npy'))]
Expand All @@ -817,8 +817,10 @@ def load_gill(model_dir: str) -> GILL:
raise ValueError(f'model_args.json does not exist in {model_dir}.')
if not os.path.exists(model_ckpt_path):
raise ValueError(f'pretrained_ckpt.pth.tar does not exist in {model_dir}.')
if len(embs_paths) == 0:
print(f'cc3m.npy files do not exist in {model_dir}. Running the model without retrieval.')
if not load_ret_embs or len(embs_paths) == 0:
if len(embs_paths) == 0:
print(f'cc3m.npy files do not exist in {model_dir}.')
print('Running the model without retrieval.')
path_array, emb_matrix = None, None
else:
# Load embeddings.
Expand Down

0 comments on commit c7de07a

Please sign in to comment.