Merge pull request chenfei-wu#103 from BinayakJha/main
Optimized the visual_chatgpt.py code to make it cleaner and more readable
chenfei-wu authored Mar 15, 2023
2 parents 807fcb8 + 149dea7 commit dfad8fb
Showing 2 changed files with 46 additions and 53 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -26,4 +26,4 @@ timm
torchmetrics
transformers
webdataset
yapf
yapf
97 changes: 45 additions & 52 deletions visual_chatgpt.py
@@ -93,36 +93,33 @@ def cut_dialogue_history(history_memory, keep_last_n_words=500):
print(f"hitory_memory:{history_memory}, n_tokens: {n_tokens}")
if n_tokens < keep_last_n_words:
return history_memory
else:
paragraphs = history_memory.split('\n')
last_n_tokens = n_tokens
while last_n_tokens >= keep_last_n_words:
last_n_tokens = last_n_tokens - len(paragraphs[0].split(' '))
paragraphs = paragraphs[1:]
return '\n' + '\n'.join(paragraphs)
paragraphs = history_memory.split('\n')
last_n_tokens = n_tokens
while last_n_tokens >= keep_last_n_words:
last_n_tokens -= len(paragraphs[0].split(' '))
paragraphs = paragraphs[1:]
return '\n' + '\n'.join(paragraphs)


def get_new_image_name(org_img_name, func_name="update"):
head_tail = os.path.split(org_img_name)
head = head_tail[0]
tail = head_tail[1]
name_split = tail.split('.')[0].split('_')
this_new_uuid = str(uuid.uuid4())[0:4]
this_new_uuid = str(uuid.uuid4())[:4]
if len(name_split) == 1:
most_org_file_name = name_split[0]
recent_prev_file_name = name_split[0]
new_file_name = '{}_{}_{}_{}.png'.format(this_new_uuid, func_name, recent_prev_file_name, most_org_file_name)
else:
assert len(name_split) == 4
most_org_file_name = name_split[3]
recent_prev_file_name = name_split[0]
new_file_name = '{}_{}_{}_{}.png'.format(this_new_uuid, func_name, recent_prev_file_name, most_org_file_name)
recent_prev_file_name = name_split[0]
new_file_name = f'{this_new_uuid}_{func_name}_{recent_prev_file_name}_{most_org_file_name}.png'
return os.path.join(head, new_file_name)


class MaskFormer:
def __init__(self, device):
print("Initializing MaskFormer to %s" % device)
print(f"Initializing MaskFormer to {device}")
self.device = device
self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
@@ -152,7 +149,7 @@ def inference(self, image_path, text):

class ImageEditing:
def __init__(self, device):
print("Initializing ImageEditing to %s" % device)
print(f"Initializing ImageEditing to {device}")
self.device = device
self.mask_former = MaskFormer(device=self.device)
self.revision = 'fp16' if 'cuda' in device else None
@@ -192,7 +189,7 @@ def inference_replace(self, inputs):

class InstructPix2Pix:
def __init__(self, device):
print("Initializing InstructPix2Pix to %s" % device)
print(f"Initializing InstructPix2Pix to {device}")
self.device = device
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained("timbrooks/instruct-pix2pix",
@@ -220,7 +217,7 @@ def inference(self, inputs):

class Text2Image:
def __init__(self, device):
print("Initializing Text2Image to %s" % device)
print(f"Initializing Text2Image to {device}")
self.device = device
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
self.pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
@@ -235,7 +232,7 @@ def __init__(self, device):
"like: generate an image of an object or something, or generate an image that includes some objects. "
"The input to this tool should be a string, representing the text used to generate image. ")
def inference(self, text):
image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
image_filename = os.path.join('image', f"{str(uuid.uuid4())[:8]}.png")
prompt = text + ', ' + self.a_prompt
image = self.pipe(prompt, negative_prompt=self.n_prompt).images[0]
image.save(image_filename)
@@ -246,7 +243,7 @@ def inference(self, text):

class ImageCaptioning:
def __init__(self, device):
print("Initializing ImageCaptioning to %s" % device)
print(f"Initializing ImageCaptioning to {device}")
self.device = device
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
@@ -290,7 +287,7 @@ def inference(self, inputs):

class CannyText2Image:
def __init__(self, device):
print("Initializing CannyText2Image to %s" % device)
print(f"Initializing CannyText2Image to {device}")
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-canny",
torch_dtype=self.torch_dtype)
@@ -302,7 +299,7 @@ def __init__(self, device):
self.seed = -1
self.a_prompt = 'best quality, extremely detailed'
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
'fewer digits, cropped, worst quality, low quality'
'fewer digits, cropped, worst quality, low quality'

@prompts(name="Generate Image Condition On Canny Image",
description="useful when you want to generate a new real image from both the user desciption and a canny image."
@@ -315,7 +312,7 @@ def inference(self, inputs):
image = Image.open(image_path)
self.seed = random.randint(0, 65535)
seed_everything(self.seed)
prompt = instruct_text + ', ' + self.a_prompt
prompt = f'{instruct_text}, {self.a_prompt}'
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
guidance_scale=9.0).images[0]
updated_image_path = get_new_image_name(image_path, func_name="canny2image")
@@ -346,7 +343,7 @@ def inference(self, inputs):

class LineText2Image:
def __init__(self, device):
print("Initializing LineText2Image to %s" % device)
print(f"Initializing LineText2Image to {device}")
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-mlsd",
torch_dtype=self.torch_dtype)
@@ -359,7 +356,7 @@ def __init__(self, device):
self.seed = -1
self.a_prompt = 'best quality, extremely detailed'
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
'fewer digits, cropped, worst quality, low quality'
'fewer digits, cropped, worst quality, low quality'

@prompts(name="Generate Image Condition On Line Image",
description="useful when you want to generate a new real image from both the user desciption "
@@ -373,7 +370,7 @@ def inference(self, inputs):
image = Image.open(image_path)
self.seed = random.randint(0, 65535)
seed_everything(self.seed)
prompt = instruct_text + ', ' + self.a_prompt
prompt = f'{instruct_text}, {self.a_prompt}'
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
guidance_scale=9.0).images[0]
updated_image_path = get_new_image_name(image_path, func_name="line2image")
@@ -404,7 +401,7 @@ def inference(self, inputs):

class HedText2Image:
def __init__(self, device):
print("Initializing HedText2Image to %s" % device)
print(f"Initializing HedText2Image to {device}")
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-hed",
torch_dtype=self.torch_dtype)
@@ -417,7 +414,7 @@ def __init__(self, device):
self.seed = -1
self.a_prompt = 'best quality, extremely detailed'
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
'fewer digits, cropped, worst quality, low quality'
'fewer digits, cropped, worst quality, low quality'

@prompts(name="Generate Image Condition On Soft Hed Boundary Image",
description="useful when you want to generate a new real image from both the user desciption "
@@ -431,7 +428,7 @@ def inference(self, inputs):
image = Image.open(image_path)
self.seed = random.randint(0, 65535)
seed_everything(self.seed)
prompt = instruct_text + ', ' + self.a_prompt
prompt = f'{instruct_text}, {self.a_prompt}'
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
guidance_scale=9.0).images[0]
updated_image_path = get_new_image_name(image_path, func_name="hed2image")
@@ -462,7 +459,7 @@ def inference(self, inputs):

class ScribbleText2Image:
def __init__(self, device):
print("Initializing ScribbleText2Image to %s" % device)
print(f"Initializing ScribbleText2Image to {device}")
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-scribble",
torch_dtype=self.torch_dtype)
@@ -475,7 +472,7 @@ def __init__(self, device):
self.seed = -1
self.a_prompt = 'best quality, extremely detailed'
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
'fewer digits, cropped, worst quality, low quality'
'fewer digits, cropped, worst quality, low quality'

@prompts(name="Generate Image Condition On Sketch Image",
description="useful when you want to generate a new real image from both the user desciption and "
@@ -487,7 +484,7 @@ def inference(self, inputs):
image = Image.open(image_path)
self.seed = random.randint(0, 65535)
seed_everything(self.seed)
prompt = instruct_text + ', ' + self.a_prompt
prompt = f'{instruct_text}, {self.a_prompt}'
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
guidance_scale=9.0).images[0]
updated_image_path = get_new_image_name(image_path, func_name="scribble2image")
@@ -517,7 +514,7 @@ def inference(self, inputs):

class PoseText2Image:
def __init__(self, device):
print("Initializing PoseText2Image to %s" % device)
print(f"Initializing PoseText2Image to {device}")
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-openpose",
torch_dtype=self.torch_dtype)
@@ -531,7 +528,7 @@ def __init__(self, device):
self.unconditional_guidance_scale = 9.0
self.a_prompt = 'best quality, extremely detailed'
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
' fewer digits, cropped, worst quality, low quality'
' fewer digits, cropped, worst quality, low quality'

@prompts(name="Generate Image Condition On Pose Image",
description="useful when you want to generate a new real image from both the user desciption "
@@ -545,7 +542,7 @@ def inference(self, inputs):
image = Image.open(image_path)
self.seed = random.randint(0, 65535)
seed_everything(self.seed)
prompt = instruct_text + ', ' + self.a_prompt
prompt = f'{instruct_text}, {self.a_prompt}'
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
guidance_scale=9.0).images[0]
updated_image_path = get_new_image_name(image_path, func_name="pose2image")
@@ -624,7 +621,7 @@ def inference(self, inputs):

class SegText2Image:
def __init__(self, device):
print("Initializing SegText2Image to %s" % device)
print(f"Initializing SegText2Image to {device}")
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-seg",
torch_dtype=self.torch_dtype)
@@ -636,7 +633,7 @@ def __init__(self, device):
self.seed = -1
self.a_prompt = 'best quality, extremely detailed'
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
' fewer digits, cropped, worst quality, low quality'
' fewer digits, cropped, worst quality, low quality'

@prompts(name="Generate Image Condition On Segmentations",
description="useful when you want to generate a new real image from both the user desciption and segmentations. "
@@ -649,7 +646,7 @@ def inference(self, inputs):
image = Image.open(image_path)
self.seed = random.randint(0, 65535)
seed_everything(self.seed)
prompt = instruct_text + ', ' + self.a_prompt
prompt = f'{instruct_text}, {self.a_prompt}'
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
guidance_scale=9.0).images[0]
updated_image_path = get_new_image_name(image_path, func_name="segment2image")
@@ -683,7 +680,7 @@ def inference(self, inputs):

class DepthText2Image:
def __init__(self, device):
print("Initializing DepthText2Image to %s" % device)
print(f"Initializing DepthText2Image to {device}")
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
self.controlnet = ControlNetModel.from_pretrained(
"fusing/stable-diffusion-v1-5-controlnet-depth", torch_dtype=self.torch_dtype)
@@ -695,7 +692,7 @@ def inference(self, inputs):
self.seed = -1
self.a_prompt = 'best quality, extremely detailed'
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
' fewer digits, cropped, worst quality, low quality'
' fewer digits, cropped, worst quality, low quality'

@prompts(name="Generate Image Condition On Depth",
description="useful when you want to generate a new real image from both the user desciption and depth image. "
@@ -708,7 +705,7 @@ def inference(self, inputs):
image = Image.open(image_path)
self.seed = random.randint(0, 65535)
seed_everything(self.seed)
prompt = instruct_text + ', ' + self.a_prompt
prompt = f'{instruct_text}, {self.a_prompt}'
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
guidance_scale=9.0).images[0]
updated_image_path = get_new_image_name(image_path, func_name="depth2image")
@@ -754,7 +751,7 @@ def inference(self, inputs):

class NormalText2Image:
def __init__(self, device):
print("Initializing NormalText2Image to %s" % device)
print(f"Initializing NormalText2Image to {device}")
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
self.controlnet = ControlNetModel.from_pretrained(
"fusing/stable-diffusion-v1-5-controlnet-normal", torch_dtype=self.torch_dtype)
@@ -766,7 +763,7 @@ def __init__(self, device):
self.seed = -1
self.a_prompt = 'best quality, extremely detailed'
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
' fewer digits, cropped, worst quality, low quality'
' fewer digits, cropped, worst quality, low quality'

@prompts(name="Generate Image Condition On Normal Map",
description="useful when you want to generate a new real image from both the user desciption and normal map. "
@@ -779,7 +776,7 @@ def inference(self, inputs):
image = Image.open(image_path)
self.seed = random.randint(0, 65535)
seed_everything(self.seed)
prompt = instruct_text + ', ' + self.a_prompt
prompt = f'{instruct_text}, {self.a_prompt}'
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
guidance_scale=9.0).images[0]
updated_image_path = get_new_image_name(image_path, func_name="normal2image")
@@ -791,7 +788,7 @@ def inference(self, inputs):

class VisualQuestionAnswering:
def __init__(self, device):
print("Initializing VisualQuestionAnswering to %s" % device)
print(f"Initializing VisualQuestionAnswering to {device}")
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
self.device = device
self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
@@ -823,12 +820,12 @@ def __init__(self, load_dict):
self.llm = OpenAI(temperature=0)
self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')

self.models = dict()
self.models = {}
for class_name, device in load_dict.items():
self.models[class_name] = globals()[class_name](device=device)

self.tools = []
for class_name, instance in self.models.items():
for instance in self.models.values():
for e in dir(instance):
if e.startswith('inference'):
func = getattr(instance, e)
@@ -855,7 +852,7 @@ def run_text(self, text, state):
return state, state

def run_image(self, image, state, txt):
image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
image_filename = os.path.join('image', f"{str(uuid.uuid4())[:8]}.png")
print("======>Auto Resize Image...")
img = Image.open(image.name)
width, height = img.size
@@ -868,17 +865,13 @@ def run_image(self, image, state, txt):
img.save(image_filename, "PNG")
print(f"Resize image form {width}x{height} to {width_new}x{height_new}")
description = self.models['ImageCaptioning'].inference(image_filename)
Human_prompt = "\nHuman: provide a figure named {}. The description is: {}. " \
"This information helps you to understand this image, " \
"but you should use tools to finish following tasks, " \
"rather than directly imagine from my description. If you understand, say \"Received\". \n".format(
image_filename, description)
Human_prompt = f'\nHuman: provide a figure named {image_filename}. The description is: {description}. This information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
AI_prompt = "Received. "
self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
state = state + [(f"![](/file={image_filename})*{image_filename}*", AI_prompt)]
print(f"\nProcessed run_image, Input image: {image_filename}\nCurrent state: {state}\n"
f"Current Memory: {self.agent.memory.buffer}")
return state, state, txt + ' ' + image_filename + ' '
return state, state, f'{txt} {image_filename} '


if __name__ == '__main__':
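For reference, the two refactored helpers can be exercised in isolation. The sketch below is a minimal standalone version: the function bodies follow the new code in this diff, while the token count in cut_dialogue_history (computed above the visible hunk) is assumed here to be a simple whitespace split.

import os
import uuid

def cut_dialogue_history(history_memory, keep_last_n_words=500):
    # Assumed from context above the visible hunk: n_tokens counts
    # whitespace-separated words in the whole history string.
    n_tokens = len(history_memory.split())
    if n_tokens < keep_last_n_words:
        return history_memory
    # Drop leading paragraphs until the running count fits the budget.
    paragraphs = history_memory.split('\n')
    last_n_tokens = n_tokens
    while last_n_tokens >= keep_last_n_words:
        last_n_tokens -= len(paragraphs[0].split(' '))
        paragraphs = paragraphs[1:]
    return '\n' + '\n'.join(paragraphs)

def get_new_image_name(org_img_name, func_name="update"):
    # New names look like <uuid4-prefix>_<func>_<prev>_<original>.png,
    # so the original stem survives any chain of edits.
    head, tail = os.path.split(org_img_name)
    name_split = tail.split('.')[0].split('_')
    this_new_uuid = str(uuid.uuid4())[:4]
    if len(name_split) == 1:
        most_org_file_name = name_split[0]
    else:
        assert len(name_split) == 4
        most_org_file_name = name_split[3]
    recent_prev_file_name = name_split[0]
    new_file_name = f'{this_new_uuid}_{func_name}_{recent_prev_file_name}_{most_org_file_name}.png'
    return os.path.join(head, new_file_name)

print(get_new_image_name('image/abcd.png', func_name='canny2image'))
# e.g. image/1f3a_canny2image_abcd_abcd.png (the four-character prefix varies per call)

The tool-registration loop in ConversationBot depends on the @prompts decorator attaching metadata to each inference method. The decorator body and the Tool construction are outside the visible hunks, so the sketch below assumes the decorator simply sets name and description attributes, and that the truncated loop body wraps each method in a LangChain Tool built from them; Echo is a hypothetical stand-in for the real model classes.

def prompts(name, description):
    # Assumed decorator: attach tool metadata to the wrapped function.
    def decorator(func):
        func.name = name
        func.description = description
        return func
    return decorator

class Echo:
    # Hypothetical stand-in for ImageCaptioning, Text2Image, etc.
    @prompts(name="Echo Text",
             description="useful when you want to echo the input back.")
    def inference(self, text):
        return text

models = {'Echo': Echo()}
tools = []
for instance in models.values():
    for e in dir(instance):
        if e.startswith('inference'):
            func = getattr(instance, e)
            # The real code presumably builds a langchain Tool with
            # name=func.name, description=func.description, func=func.
            tools.append((func.name, func.description, func))
print(tools)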
