Skip to content

Commit

Permalink
Update visual_chatgpt.py
Browse files Browse the repository at this point in the history
  • Loading branch information
BinayakJha authored Mar 12, 2023
1 parent 31433ee commit a1412fb
Showing 1 changed file with 96 additions and 39 deletions.
135 changes: 96 additions & 39 deletions visual_chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,29 +84,26 @@ def cut_dialogue_history(history_memory, keep_last_n_words=500):
print(f"hitory_memory:{history_memory}, n_tokens: {n_tokens}")
if n_tokens < keep_last_n_words:
return history_memory
else:
paragraphs = history_memory.split('\n')
last_n_tokens = n_tokens
while last_n_tokens >= keep_last_n_words:
last_n_tokens = last_n_tokens - len(paragraphs[0].split(' '))
paragraphs = paragraphs[1:]
return '\n' + '\n'.join(paragraphs)
paragraphs = history_memory.split('\n')
last_n_tokens = n_tokens
while last_n_tokens >= keep_last_n_words:
last_n_tokens -= len(paragraphs[0].split(' '))
paragraphs = paragraphs[1:]
return '\n' + '\n'.join(paragraphs)

def get_new_image_name(org_img_name, func_name="update"):
head_tail = os.path.split(org_img_name)
head = head_tail[0]
tail = head_tail[1]
name_split = tail.split('.')[0].split('_')
this_new_uuid = str(uuid.uuid4())[0:4]
this_new_uuid = str(uuid.uuid4())[:4]
if len(name_split) == 1:
most_org_file_name = name_split[0]
recent_prev_file_name = name_split[0]
new_file_name = '{}_{}_{}_{}.png'.format(this_new_uuid, func_name, recent_prev_file_name, most_org_file_name)
else:
assert len(name_split) == 4
most_org_file_name = name_split[3]
recent_prev_file_name = name_split[0]
new_file_name = '{}_{}_{}_{}.png'.format(this_new_uuid, func_name, recent_prev_file_name, most_org_file_name)
recent_prev_file_name = name_split[0]
new_file_name = f'{this_new_uuid}_{func_name}_{recent_prev_file_name}_{most_org_file_name}.png'
return os.path.join(head, new_file_name)

def create_model(config_path, device):
Expand Down Expand Up @@ -146,7 +143,7 @@ def inference(self, image_path, text):

class ImageEditing:
def __init__(self, device):
print("Initializing StableDiffusionInpaint to %s" % device)
print(f"Initializing StableDiffusionInpaint to {device}")
self.device = device
self.mask_former = MaskFormer(device=self.device)
self.inpainting = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting",).to(device)
Expand All @@ -168,7 +165,7 @@ def replace_part_of_image(self, input):

class Pix2Pix:
def __init__(self, device):
print("Initializing Pix2Pix to %s" % device)
print(f"Initializing Pix2Pix to {device}")
self.device = device
self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained("timbrooks/instruct-pix2pix", torch_dtype=torch.float16, safety_checker=None).to(device)
self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config)
Expand All @@ -185,7 +182,7 @@ def inference(self, inputs):

class T2I:
def __init__(self, device):
print("Initializing T2I to %s" % device)
print(f"Initializing T2I to {device}")
self.device = device
self.pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
self.text_refine_tokenizer = AutoTokenizer.from_pretrained("Gustavosta/MagicPrompt-Stable-Diffusion")
Expand All @@ -194,7 +191,7 @@ def __init__(self, device):
self.pipe.to(device)

def inference(self, text):
image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
image_filename = os.path.join('image', f"{str(uuid.uuid4())[:8]}.png")
refined_text = self.text_refine_gpt2_pipe(text)[0]["generated_text"]
print(f'{text} refined to {refined_text}')
image = self.pipe(refined_text).images[0]
Expand All @@ -204,16 +201,15 @@ def inference(self, text):

class ImageCaptioning:
def __init__(self, device):
print("Initializing ImageCaptioning to %s" % device)
print(f"Initializing ImageCaptioning to {device}")
self.device = device
self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(self.device)

def inference(self, image_path):
inputs = self.processor(Image.open(image_path), return_tensors="pt").to(self.device)
out = self.model.generate(**inputs)
captions = self.processor.decode(out[0], skip_special_tokens=True)
return captions
return self.processor.decode(out[0], skip_special_tokens=True)

class image2canny:
def __init__(self):
Expand Down Expand Up @@ -268,7 +264,14 @@ def inference(self, inputs):
seed_everything(self.seed)
if self.save_memory:
self.model.low_vram_shift(is_diffusing=False)
cond = {"c_concat": [control], "c_crossattn": [self.model.get_learned_conditioning([prompt + ', ' + self.a_prompt] * self.num_samples)]}
cond = {
"c_concat": [control],
"c_crossattn": [
self.model.get_learned_conditioning(
[f'{prompt}, {self.a_prompt}'] * self.num_samples
)
],
}
un_cond = {"c_concat": None if self.guess_mode else [control], "c_crossattn": [self.model.get_learned_conditioning([self.n_prompt] * self.num_samples)]}
shape = (4, H // 8, W // 8)
self.model.control_scales = [self.strength * (0.825 ** float(12 - i)) for i in range(13)] if self.guess_mode else ([self.strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
Expand Down Expand Up @@ -339,7 +342,14 @@ def inference(self, inputs):
seed_everything(self.seed)
if self.save_memory:
self.model.low_vram_shift(is_diffusing=False)
cond = {"c_concat": [control], "c_crossattn": [self.model.get_learned_conditioning([prompt + ', ' + self.a_prompt] * self.num_samples)]}
cond = {
"c_concat": [control],
"c_crossattn": [
self.model.get_learned_conditioning(
[f'{prompt}, {self.a_prompt}'] * self.num_samples
)
],
}
un_cond = {"c_concat": None if self.guess_mode else [control], "c_crossattn": [self.model.get_learned_conditioning([self.n_prompt] * self.num_samples)]}
shape = (4, H // 8, W // 8)
self.model.control_scales = [self.strength * (0.825 ** float(12 - i)) for i in range(13)] if self.guess_mode else ([self.strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
Expand All @@ -348,7 +358,7 @@ def inference(self, inputs):
self.model.low_vram_shift(is_diffusing=False)
x_samples = self.model.decode_first_stage(samples)
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).\
cpu().numpy().clip(0,255).astype(np.uint8)
cpu().numpy().clip(0,255).astype(np.uint8)
updated_image_path = get_new_image_name(image_path, func_name="line2image")
real_image = Image.fromarray(x_samples[0]) # default the index0 image
real_image.save(updated_image_path)
Expand Down Expand Up @@ -408,7 +418,14 @@ def inference(self, inputs):
seed_everything(self.seed)
if self.save_memory:
self.model.low_vram_shift(is_diffusing=False)
cond = {"c_concat": [control], "c_crossattn": [self.model.get_learned_conditioning([prompt + ', ' + self.a_prompt] * self.num_samples)]}
cond = {
"c_concat": [control],
"c_crossattn": [
self.model.get_learned_conditioning(
[f'{prompt}, {self.a_prompt}'] * self.num_samples
)
],
}
un_cond = {"c_concat": None if self.guess_mode else [control], "c_crossattn": [self.model.get_learned_conditioning([self.n_prompt] * self.num_samples)]}
shape = (4, H // 8, W // 8)
self.model.control_scales = [self.strength * (0.825 ** float(12 - i)) for i in range(13)] if self.guess_mode else ([self.strength] * 13)
Expand Down Expand Up @@ -485,7 +502,14 @@ def inference(self, inputs):
seed_everything(self.seed)
if self.save_memory:
self.model.low_vram_shift(is_diffusing=False)
cond = {"c_concat": [control], "c_crossattn": [self.model.get_learned_conditioning([prompt + ', ' + self.a_prompt] * self.num_samples)]}
cond = {
"c_concat": [control],
"c_crossattn": [
self.model.get_learned_conditioning(
[f'{prompt}, {self.a_prompt}'] * self.num_samples
)
],
}
un_cond = {"c_concat": None if self.guess_mode else [control], "c_crossattn": [self.model.get_learned_conditioning([self.n_prompt] * self.num_samples)]}
shape = (4, H // 8, W // 8)
self.model.control_scales = [self.strength * (0.825 ** float(12 - i)) for i in range(13)] if self.guess_mode else ([self.strength] * 13)
Expand Down Expand Up @@ -555,7 +579,14 @@ def inference(self, inputs):
seed_everything(self.seed)
if self.save_memory:
self.model.low_vram_shift(is_diffusing=False)
cond = {"c_concat": [control], "c_crossattn": [ self.model.get_learned_conditioning([prompt + ', ' + self.a_prompt] * self.num_samples)]}
cond = {
"c_concat": [control],
"c_crossattn": [
self.model.get_learned_conditioning(
[f'{prompt}, {self.a_prompt}'] * self.num_samples
)
],
}
un_cond = {"c_concat": None if self.guess_mode else [control], "c_crossattn": [self.model.get_learned_conditioning([self.n_prompt] * self.num_samples)]}
shape = (4, H // 8, W // 8)
self.model.control_scales = [self.strength * (0.825 ** float(12 - i)) for i in range(13)] if self.guess_mode else ([self.strength] * 13)
Expand Down Expand Up @@ -625,7 +656,14 @@ def inference(self, inputs):
seed_everything(self.seed)
if self.save_memory:
self.model.low_vram_shift(is_diffusing=False)
cond = {"c_concat": [control], "c_crossattn": [self.model.get_learned_conditioning([prompt + ', ' + self.a_prompt] * self.num_samples)]}
cond = {
"c_concat": [control],
"c_crossattn": [
self.model.get_learned_conditioning(
[f'{prompt}, {self.a_prompt}'] * self.num_samples
)
],
}
un_cond = {"c_concat": None if self.guess_mode else [control], "c_crossattn": [self.model.get_learned_conditioning([self.n_prompt] * self.num_samples)]}
shape = (4, H // 8, W // 8)
self.model.control_scales = [self.strength * (0.825 ** float(12 - i)) for i in range(13)] if self.guess_mode else ([self.strength] * 13)
Expand Down Expand Up @@ -695,7 +733,14 @@ def inference(self, inputs):
seed_everything(self.seed)
if self.save_memory:
self.model.low_vram_shift(is_diffusing=False)
cond = {"c_concat": [control], "c_crossattn": [ self.model.get_learned_conditioning([prompt + ', ' + self.a_prompt] * self.num_samples)]}
cond = {
"c_concat": [control],
"c_crossattn": [
self.model.get_learned_conditioning(
[f'{prompt}, {self.a_prompt}'] * self.num_samples
)
],
}
un_cond = {"c_concat": None if self.guess_mode else [control], "c_crossattn": [self.model.get_learned_conditioning([self.n_prompt] * self.num_samples)]}
shape = (4, H // 8, W // 8)
self.model.control_scales = [self.strength * (0.825 ** float(12 - i)) for i in range(13)] if self.guess_mode else ([self.strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
Expand Down Expand Up @@ -767,7 +812,14 @@ def inference(self, inputs):
seed_everything(self.seed)
if self.save_memory:
self.model.low_vram_shift(is_diffusing=False)
cond = {"c_concat": [control], "c_crossattn": [self.model.get_learned_conditioning([prompt + ', ' + self.a_prompt] * self.num_samples)]}
cond = {
"c_concat": [control],
"c_crossattn": [
self.model.get_learned_conditioning(
[f'{prompt}, {self.a_prompt}'] * self.num_samples
)
],
}
un_cond = {"c_concat": None if self.guess_mode else [control], "c_crossattn": [self.model.get_learned_conditioning([self.n_prompt] * self.num_samples)]}
shape = (4, H // 8, W // 8)
self.model.control_scales = [self.strength * (0.825 ** float(12 - i)) for i in range(13)] if self.guess_mode else ([self.strength] * 13)
Expand All @@ -783,7 +835,7 @@ def inference(self, inputs):

class BLIPVQA:
def __init__(self, device):
print("Initializing BLIP VQA to %s" % device)
print(f"Initializing BLIP VQA to {device}")
self.device = device
self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
self.model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(self.device)
Expand Down Expand Up @@ -901,9 +953,9 @@ def __init__(self):
agent_kwargs={'prefix': VISUAL_CHATGPT_PREFIX, 'format_instructions': VISUAL_CHATGPT_FORMAT_INSTRUCTIONS, 'suffix': VISUAL_CHATGPT_SUFFIX}, )

def run_text(self, text, state):
print("===============Running run_text =============")
print("Inputs:", text, state)
print("======>Previous memory:\n %s" % self.agent.memory)
self._extracted_from_run_image_2(
"===============Running run_text =============", text, state
)
self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
res = self.agent({"input": text})
print("======>Current memory:\n %s" % self.agent.memory)
Expand All @@ -913,10 +965,10 @@ def run_text(self, text, state):
return state, state

def run_image(self, image, state, txt):
print("===============Running run_image =============")
print("Inputs:", image, state)
print("======>Previous memory:\n %s" % self.agent.memory)
image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
self._extracted_from_run_image_2(
"===============Running run_image =============", image, state
)
image_filename = os.path.join('image', f"{str(uuid.uuid4())[:8]}.png")
print("======>Auto Resize Image...")
img = Image.open(image.name)
width, height = img.size
Expand All @@ -927,14 +979,19 @@ def run_image(self, image, state, txt):
img.save(image_filename, "PNG")
print(f"Resize image form {width}x{height} to {width_new}x{height_new}")
description = self.i2t.inference(image_filename)
Human_prompt = "\nHuman: provide a figure named {}. The description is: {}. This information helps you to understand this image, but you should use tools to finish following tasks, " \
"rather than directly imagine from my description. If you understand, say \"Received\". \n".format(image_filename, description)
Human_prompt = f'\nHuman: provide a figure named {image_filename}. The description is: {description}. This information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
AI_prompt = "Received. "
self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
print("======>Current memory:\n %s" % self.agent.memory)
state = state + [(f"![](/file={image_filename})*{image_filename}*", AI_prompt)]
print("Outputs:", state)
return state, state, txt + ' ' + image_filename + ' '
return state, state, f'{txt} {image_filename} '

# TODO Rename this here and in `run_text` and `run_image`
def _extracted_from_run_image_2(self, arg0, arg1, state):
print(arg0)
print("Inputs:", arg1, state)
print("======>Previous memory:\n %s" % self.agent.memory)

if __name__ == '__main__':
bot = ConversationBot()
Expand Down

0 comments on commit a1412fb

Please sign in to comment.