Skip to content

Commit

Permalink
Implement general edit direction parser insted of hardcoding
Browse files Browse the repository at this point in the history
  • Loading branch information
10maurycy10 committed Feb 14, 2023
1 parent c792398 commit 0d14bdc
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions src/utils/edit_directions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import os
import torch

if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"


"""
This function takes in a task name and returns the direction in the embedding space that transforms class A to class B for the given task.
Expand All @@ -15,15 +20,8 @@
>>> construct_direction("cat2dog")
"""
def construct_direction(task_name):
if task_name=="cat2dog":
emb_dir = f"assets/embeddings_sd_1.4"
embs_a = torch.load(os.path.join(emb_dir, f"cat.pt"))
embs_b = torch.load(os.path.join(emb_dir, f"dog.pt"))
return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
elif task_name=="dog2cat":
emb_dir = f"assets/embeddings_sd_1.4"
embs_a = torch.load(os.path.join(emb_dir, f"dog.pt"))
embs_b = torch.load(os.path.join(emb_dir, f"cat.pt"))
return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
else:
raise NotImplementedError
(src, dst) = task_name.split("2")
emb_dir = f"assets/embeddings_sd_1.4"
embs_a = torch.load(os.path.join(emb_dir, f"{src}.pt"), map_location=device)
embs_b = torch.load(os.path.join(emb_dir, f"{dst}.pt"), map_location=device)
return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)

0 comments on commit 0d14bdc

Please sign in to comment.