main.py
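
"""Command-line entrypoint for HuggingGPT.

Plans tasks from a user prompt, selects a Hugging Face model for each task,
runs model inference, and generates a final natural-language response.
"""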

import asyncio
import json
import logging

import click
import requests
from dotenv import load_dotenv

from hugginggpt import generate_response, infer, plan_tasks
from hugginggpt.history import ConversationHistory
from hugginggpt.llm_factory import LLMs, create_llms
from hugginggpt.log import setup_logging
from hugginggpt.model_inference import TaskSummary
from hugginggpt.model_selection import select_hf_models
from hugginggpt.response_generation import format_response

load_dotenv()
setup_logging()
logger = logging.getLogger(__name__)


@click.command()
@click.option("-p", "--prompt", type=str, help="Prompt for huggingGPT")
def main(prompt):
    _print_banner()
    llms = create_llms()
    if prompt:
        standalone_mode(user_input=prompt, llms=llms)
    else:
        interactive_mode(llms=llms)


def standalone_mode(user_input: str, llms: LLMs) -> str:
    try:
        response, task_summaries = compute(
            user_input=user_input,
            history=ConversationHistory(),
            llms=llms,
        )
        pretty_response = format_response(response)
        print(pretty_response)
        return pretty_response
    except Exception as e:
        logger.exception("")
        print(
            f"Sorry, encountered error: {e}. Please try again. Check logs if problem persists."
        )


def interactive_mode(llms: LLMs):
    print("Please enter your request. End the conversation with 'exit'")
    history = ConversationHistory()
    while True:
        try:
            user_input = click.prompt("User")
            if user_input.lower() == "exit":
                break
            logger.info(f"User input: {user_input}")
            response, task_summaries = compute(
                user_input=user_input,
                history=history,
                llms=llms,
            )
            pretty_response = format_response(response)
            print(f"Assistant: {pretty_response}")
            history.add(role="user", content=user_input)
            history.add(role="assistant", content=response)
        except Exception as e:
            logger.exception("")
            print(
                f"Sorry, encountered error: {e}. Please try again. Check logs if problem persists."
            )
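

# compute() runs the four HuggingGPT stages end to end: task planning,
# model selection, model inference, and response generation, each driven
# by its own LLM from the LLMs factory.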
def compute(
    user_input: str,
    history: ConversationHistory,
    llms: LLMs,
) -> tuple[str, list[TaskSummary]]:
    tasks = plan_tasks(
        user_input=user_input, history=history, llm=llms.task_planning_llm
    )
    # Sort tasks so that dependencies run first; sorted() returns a new list,
    # so the result must be assigned back.
    tasks = sorted(tasks, key=lambda t: max(t.dep))
    logger.info(f"Sorted tasks: {tasks}")
    hf_models = asyncio.run(
        select_hf_models(
            user_input=user_input,
            tasks=tasks,
            model_selection_llm=llms.model_selection_llm,
            output_fixing_llm=llms.output_fixing_llm,
        )
    )
    task_summaries = []
    with requests.Session() as session:
        for task in tasks:
            logger.info(f"Starting task: {task}")
            if task.depends_on_generated_resources():
                task = task.replace_generated_resources(task_summaries=task_summaries)
            model = hf_models[task.id]
            inference_result = infer(
                task=task,
                model_id=model.id,
                llm=llms.model_inference_llm,
                session=session,
            )
            task_summaries.append(
                TaskSummary(
                    task=task,
                    model=model,
                    inference_result=json.dumps(inference_result),
                )
            )
            logger.info(f"Finished task: {task}")
    logger.info("Finished all tasks")
    logger.debug(f"Task summaries: {task_summaries}")
    response = generate_response(
        user_input=user_input,
        task_summaries=task_summaries,
        llm=llms.response_generation_llm,
    )
    return response, task_summaries


def _print_banner():
    with open("resources/banner.txt", "r") as f:
        banner = f.read()
    logger.info("\n" + banner)


if __name__ == "__main__":
    main()
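
# Example invocations (assumes the required API keys are set in .env):
#   python main.py -p "Draw me a sheep."   # standalone mode, single prompt
#   python main.py                         # interactive chat session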