Commit

update readme; demo
yadong-lu committed Oct 9, 2024
1 parent 4a8758b commit 664407a
Showing 6 changed files with 173 additions and 534 deletions.
3 changes: 2 additions & 1 deletion README.md
````diff
@@ -18,7 +18,8 @@
 ## Install
 ```python
 conda create -n "omni" python==3.12
-pip install -r requirements.txt
+conda activate omni
+pip install -r requirement.txt
 ```

 ## Examples:
````
Binary file modified __pycache__/utils.cpython-312.pyc
470 changes: 160 additions & 310 deletions demo.ipynb

Large diffs are not rendered by default.

Binary file added util/__pycache__/__init__.cpython-312.pyc
Binary file added util/__pycache__/box_annotator.cpython-312.pyc
234 changes: 11 additions & 223 deletions utils.py
```diff
@@ -18,7 +18,7 @@
 # %matplotlib inline
 from matplotlib import pyplot as plt
 import easyocr
-reader = easyocr.Reader(['en']) # this needs to run only once to load the model into memory # 'ch_sim',
+reader = easyocr.Reader(['en'])
 import time
 import base64
```
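For reference, a minimal sketch of how the module-level `reader` created above is typically used: EasyOCR returns a list of (4-point box, text, confidence) triples, which `check_ocr_box` further down unpacks. The file name and argument values here are illustrative, not taken from the commit.

```python
# Illustrative only: the screenshot path and readtext arguments are assumptions.
results = reader.readtext("screenshot.png", paragraph=False, text_threshold=0.9)
for box, text, conf in results:
    # box is a list of four [x, y] corner points; text is the recognized string.
    print(box, text, round(conf, 3))
```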

```diff
@@ -33,44 +33,19 @@
 import torchvision.transforms as T


-def get_caption_model_processor(model_name="Salesforce/blip2-opt-2.7b", device=None):
+def get_caption_model_processor(model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
     if not device:
         device = "cuda" if torch.cuda.is_available() else "cpu"
-    if model_name == "Salesforce/blip2-opt-2.7b":
-        from transformers import Blip2Processor, Blip2ForConditionalGeneration
-        processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
+    from transformers import Blip2Processor, Blip2ForConditionalGeneration
+    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
     if device == 'cpu':
         model = Blip2ForConditionalGeneration.from_pretrained(
-            "Salesforce/blip2-opt-2.7b", device_map=None, torch_dtype=torch.float16
-            # '/home/yadonglu/sandbox/data/orca/blipv2_ui_merge', device_map=None, torch_dtype=torch.float16
-        )
-    elif model_name == "blip2-opt-2.7b-ui":
-        from transformers import Blip2Processor, Blip2ForConditionalGeneration
-        processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
-        if device == 'cpu':
-            model = Blip2ForConditionalGeneration.from_pretrained(
-                '/home/yadonglu/sandbox/data/orca/blipv2_ui_merge', device_map=None, torch_dtype=torch.float32
-            )
-        else:
-            model = Blip2ForConditionalGeneration.from_pretrained(
-                '/home/yadonglu/sandbox/data/orca/blipv2_ui_merge', device_map=None, torch_dtype=torch.float16
-            )
-    elif model_name == "florence":
-        from transformers import AutoProcessor, AutoModelForCausalLM
-        processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
-        if device == 'cpu':
-            model = AutoModelForCausalLM.from_pretrained("/home/yadonglu/sandbox/data/orca/florence-2-base-ft-fft_ep1_rai", torch_dtype=torch.float32, trust_remote_code=True)#.to(device)
-        else:
-            model = AutoModelForCausalLM.from_pretrained("/home/yadonglu/sandbox/data/orca/florence-2-base-ft-fft_ep1_rai_win_ep5_fixed", torch_dtype=torch.float16, trust_remote_code=True).to(device)
-    elif model_name == 'phi3v_ui':
-        from transformers import AutoModelForCausalLM, AutoProcessor
-        model_id = "microsoft/Phi-3-vision-128k-instruct"
-        model = AutoModelForCausalLM.from_pretrained('/home/yadonglu/sandbox/data/orca/phi3v_ui', device_map=device, trust_remote_code=True, torch_dtype="auto")
-        processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
-    elif model_name == 'phi3v':
-        from transformers import AutoModelForCausalLM, AutoProcessor
-        model_id = "microsoft/Phi-3-vision-128k-instruct"
-        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, trust_remote_code=True, torch_dtype="auto")
-        processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+            model_name_or_path, device_map=None, torch_dtype=torch.float32
+        )
+    else:
+        model = Blip2ForConditionalGeneration.from_pretrained(
+            model_name_or_path, device_map=None, torch_dtype=torch.float16
+        )
     return {'model': model.to(device), 'processor': processor}
```
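The refactored loader above returns a dict holding the model and its processor. A minimal usage sketch under stated assumptions: the cropped-icon path is hypothetical, and the generation settings (e.g. `max_new_tokens`) are illustrative rather than taken from this commit.

```python
from PIL import Image
import torch

caption_model_processor = get_caption_model_processor()  # defaults to Salesforce/blip2-opt-2.7b
model = caption_model_processor['model']
processor = caption_model_processor['processor']

icon = Image.open("icon_crop.png").convert("RGB")  # hypothetical cropped UI element
inputs = processor(images=icon, text="The image shows", return_tensors="pt")
inputs = inputs.to(device=model.device, dtype=model.dtype)  # float32 on CPU, float16 on GPU per the diff

with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=20)
caption = processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
print(caption)
```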


```diff
@@ -94,14 +69,12 @@ def get_parsed_content_icon(filtered_boxes, ocr_bbox, image_source, caption_mode
         cropped_image = image_source[ymin:ymax, xmin:xmax, :]
         croped_pil_image.append(to_pil(cropped_image))

-    # import pdb; pdb.set_trace()
     model, processor = caption_model_processor['model'], caption_model_processor['processor']
     if not prompt:
         if 'florence' in model.config.name_or_path:
             prompt = "<CAPTION>"
         else:
             prompt = "The image shows"
-    # prompt = "NO gender!NO gender!NO gender! The image shows a icon:"

     batch_size = 10 # Number of samples per batch
     generated_texts = []
```
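The hunk above keeps the batching setup for icon captioning. A hedged sketch of the loop those variables feed, mirroring the BLIP-2 path; `croped_pil_image`, `prompt`, `batch_size`, `generated_texts`, `model`, and `processor` come from the code above, while the `generate()` arguments are assumptions not shown in this diff.

```python
import torch

for i in range(0, len(croped_pil_image), batch_size):
    batch = croped_pil_image[i:i + batch_size]
    # Assumes a BLIP-2 style processor; the Florence prompt path would differ.
    inputs = processor(images=batch, text=[prompt] * len(batch), return_tensors="pt")
    inputs = inputs.to(device=model.device, dtype=model.dtype)
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=20)
    decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)
    generated_texts.extend(text.strip() for text in decoded)
```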
```diff
@@ -387,117 +360,15 @@ def get_xywh_yolo(input):
     return x, y, w, h


-def run_api(body, max_tokens=1024):
-    '''
-    API call, check https://platform.openai.com/docs/guides/vision for the latest api usage.
-    '''
-    max_num_trial = 3
-    num_trial = 0
-    while num_trial < max_num_trial:
-        try:
-            response = client.chat.completions.create(
-                model=deployment,
-                messages=body,
-                temperature=0.01,
-                max_tokens=max_tokens,
-            )
-            return response.choices[0].message.content
-        except:
-            print('retry call gptv', num_trial)
-            num_trial += 1
-            time.sleep(10)
-    return ''
-
-def call_gpt4v_new(message_text, image_path=None, max_tokens=2048):
-    if image_path:
-        try:
-            with open(image_path, "rb") as img_file:
-                encoded_image = base64.b64encode(img_file.read()).decode('ascii')
-        except:
-            encoded_image = image_path
-
-    if image_path:
-        content = [{"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}, {"type": "text","text": message_text},]
-    else:
-        content = [{"type": "text","text": message_text},]
-
-    max_num_trial = 3
-    num_trial = 0
-    call_api_success = True
-
-    while num_trial < max_num_trial:
-        try:
-            response = client.chat.completions.create(
-                model=deployment,
-                messages=[
-                    {
-                        "role": "system",
-                        "content": [
-                            {
-                                "type": "text",
-                                "text": "You are an AI assistant that is good at making plans and analyzing screens, and helping people find information."
-                            },
-                        ]
-                    },
-                    {
-                        "role": "user",
-                        "content": content
-                    }
-                ],
-                temperature=0.01,
-                max_tokens=max_tokens,
-            )
-            ans_1st_pass = response.choices[0].message.content
-            break
-        except:
-            print('retry call gptv', num_trial)
-            num_trial += 1
-            ans_1st_pass = ''
-            time.sleep(10)
-    if num_trial == max_num_trial:
-        call_api_success = False
-    return ans_1st_pass, call_api_success


 def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None):
     if easyocr_args is None:
         easyocr_args = {}
     result = reader.readtext(image_path, **easyocr_args)
     is_goal_filtered = False
-    if goal_filtering:
-        ocr_filter_fs = "Example 1:\n Based on task and ocr results, ```In summary, the task related bboxes are: [([[3060, 111], [3135, 111], [3135, 141], [3060, 141]], 'Share', 0.949013667261589), ([[3068, 197], [3135, 197], [3135, 227], [3068, 227]], 'Link _', 0.3567054243152049), ([[3006, 321], [3178, 321], [3178, 354], [3006, 354]], 'Manage Access', 0.8800734456437066)] ``` \n Example 2:\n Based on task and ocr results, ```In summary, the task related bboxes are: [([[3060, 111], [3135, 111], [3135, 141], [3060, 141]], 'Search Google or type a URL', 0.949013667261589)] ```"
-        # message_text = f"Based on the ocr results which contains text+bounding box in a dictionary, please filter it so that it only contains the task related bboxes. The task is: {goal_filtering}, the ocr results are: {str(result)}. Your final answer should be in the exact same format as the ocr results, please do not include any other redundant information, please do not include any analysis."
-        message_text = f"Based on the task and ocr results which contains text+bounding box in a dictionary, please filter it so that it only contains the task related bboxes. Requirement: 1. first give a brief analysis. 2. provide an answer in the format: ```In summary, the task related bboxes are: ..```, you must put it inside ``` ```. Do not include any info after ```.\n {ocr_filter_fs}\n The task is: {goal_filtering}, the ocr results are: {str(result)}."
-
-        prompt = [{"role":"system", "content": "You are an AI assistant that helps people find the correct way to operate computer or smartphone."}, {"role":"user","content": message_text},]
-        print('[Perform OCR filtering by goal] ongoing ...')
-        # pred, _, _ = call_gpt4(prompt)
-        pred, _, = call_gpt4v(message_text)
-        # import pdb; pdb.set_trace()
-        try:
-            # match = re.search(r"```(.*?)```", pred, re.DOTALL)
-            # result = match.group(1).strip()
-            # pred = result.split('In summary, the task related bboxes are:')[-1].strip()
-            pred = pred.split('In summary, the task related bboxes are:')[-1].strip().strip('```')
-            result = ast.literal_eval(pred)
-            print('[Perform OCR filtering by goal] success!!! Filtered buttons: ', pred)
-            is_goal_filtered = True
-        except:
-            print('[Perform OCR filtering by goal] failed or unused!!!')
-            pass
-        # added_prompt = [{"role":"assistant","content":pred},
-        # {"role":"user","content": "given the previous answers, please provide the final answer in the exact same format as the ocr results, please do not include any other redundant information, please do not include any analysis."}]
-        # prompt.extend(added_prompt)
-        # pred, _, _ = call_gpt4(prompt)
-        # print('goal filtering pred 2nd:', pred)
-        # result = ast.literal_eval(pred)
-        # print('goal filtering pred:', result[-5:])
     coord = [item[0] for item in result]
     text = [item[1] for item in result]
-    # confidence = [item[2] for item in result]
-    # if confidence_filtering:
-    # coord = [coord[i] for i in range(len(coord)) if confidence[i] > confidence_filtering]
-    # text = [text[i] for i in range(len(text)) if confidence[i] > confidence_filtering]
     # read the image using cv2
     if display_img:
         opencv_img = cv2.imread(image_path)
```
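The `output_bb_format='xywh'` argument kept in the hunk above implies converting EasyOCR's 4-point boxes into x/y/width/height boxes; the file's `get_xywh` helpers (one is visible in the hunk header) presumably do this. A minimal standalone sketch of that conversion, with a hypothetical helper name:

```python
def polygon_to_xywh(polygon):
    # polygon is a list of four [x, y] corner points as returned by easyocr
    xs = [p[0] for p in polygon]
    ys = [p[1] for p in polygon]
    x, y = min(xs), min(ys)
    return [x, y, max(xs) - x, max(ys) - y]

# e.g. the box around 'Share' from the few-shot example in the removed prompt above
print(polygon_to_xywh([[3060, 111], [3135, 111], [3135, 141], [3060, 141]]))  # [3060, 111, 75, 30]
```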
```diff
@@ -520,87 +391,4 @@ def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_
     return (text, bb), is_goal_filtered


-def get_pred_gptv(message_text, yolo_labled_img, label_coordinates, summarize_history=True, verbose=True, history=None, id_key='Click ID'):
-    """ This func first
-    1. call gptv(yolo_labled_img, text bbox+task) -> ans_1st_cal
-    2. call gpt4(ans_1st_cal, label_coordinates) -> final ans
-    """
-
-    # Configuration
-    encoded_image = yolo_labled_img
-
-    # Payload for the request
-    if not history:
-        messages = [
-            {"role": "system", "content": [{"type": "text","text": "You are an AI assistant that is great at interpreting screenshot and predict action."},]},
-            {"role": "user","content": [{"type": "text","text": message_text}, {"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},]}
-        ]
-    else:
-        messages = [
-            {"role": "system", "content": [{"type": "text","text": "You are an AI assistant that is great at interpreting screenshot and predict action."},]},
-            history,
-            {"role": "user","content": [{"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},{"type": "text","text": message_text},]}
-        ]
-
-    payload = {
-        "messages": messages,
-        "temperature": 0.01, # 0.01
-        "top_p": 0.95,
-        "max_tokens": 800
-    }
-
-    max_num_trial = 3
-    num_trial = 0
-    call_api_success = True
-    while num_trial < max_num_trial:
-        try:
-            # response = requests.post(GPT4V_ENDPOINT, headers=headers, json=payload)
-            # response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
-            # ans_1st_pass = response.json()['choices'][0]['message']['content']
-            response = client.chat.completions.create(
-                model=deployment,
-                messages=messages,
-                temperature=0.01,
-                max_tokens=512,
-            )
-            ans_1st_pass = response.choices[0].message.content
-            break
-        except requests.RequestException as e:
-            print('retry call gptv', num_trial)
-            num_trial += 1
-            ans_1st_pass = ''
-            time.sleep(30)
-            # raise SystemExit(f"Failed to make the request. Error: {e}")
-    if num_trial == max_num_trial:
-        call_api_success = False
-    if verbose:
-        print('Answer by GPTV: ', ans_1st_pass)
-    # extract by simple parsing
-    try:
-        match = re.search(r"```(.*?)```", ans_1st_pass, re.DOTALL)
-        if match:
-            result = match.group(1).strip()
-            pred = result.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
-            pred = ast.literal_eval(pred)
-        else:
-            pred = ans_1st_pass.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
-            pred = ast.literal_eval(pred)
-
-        if id_key in pred:
-            icon_id = pred[id_key]
-            bbox = label_coordinates[str(icon_id)]
-            pred['click_point'] = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
-    except:
-        # import pdb; pdb.set_trace()
-        print('gptv action regex extract fail!!!')
-        print('ans_1st_pass:', ans_1st_pass)
-        pred = {'action_type': 'CLICK', 'click_point': [0, 0], 'value': 'None', 'is_completed': False}
-
-    step_pred_summary = None
-    if summarize_history:
-        step_pred_summary, _ = call_gpt4v_new('Summarize what action you decide to perform in the current step, in one sentence, and do not include any icon box number: ' + ans_1st_pass, max_tokens=128)
-        print('step_pred_summary', step_pred_summary)
-    return pred, [call_api_success, ans_1st_pass, None, step_pred_summary]
-    # return pred, [call_api_success, message_2nd, completion_2nd.choices[0].message.content, step_pred_summary]
```
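The removed get_pred_gptv converted a predicted 'Click ID' into screen coordinates by taking the center of the corresponding xywh box. A small worked example of that arithmetic, with illustrative values:

```python
bbox = [100, 40, 60, 20]  # x, y, w, h as stored in label_coordinates (illustrative values)
click_point = [bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2]
assert click_point == [130.0, 50.0]  # horizontal and vertical center of the box
```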

