forked from stanford-oval/genie-worksheets
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinterface_utils.py
105 lines (83 loc) · 3.28 KB
/
interface_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import json
from worksheets.annotation_utils import get_agent_action_schemas, get_context_schema
from worksheets.chat import generate_next_turn
from worksheets.modules.dialogue import CurrentDialogueTurn
def convert_to_json(dialogue: list[CurrentDialogueTurn]):
    """Turn a dialogue history into a list of per-turn dictionaries.

    Every turn contributes one dictionary containing the user utterance, the
    system response, the schemas produced for the turn-level and global
    contexts, the schemas produced for the system action, and the three
    ``user_target*`` attributes copied from the turn unchanged.

    Args:
        dialogue (list[CurrentDialogueTurn]): The dialogue history.

    Returns:
        list[dict]: One dictionary per turn, in the same order as ``dialogue``.
    """
    return [
        {
            "user": entry.user_utterance,
            "bot": entry.system_response,
            "turn_context": get_context_schema(entry.context),
            "global_context": get_context_schema(entry.global_context),
            "system_action": get_agent_action_schemas(entry.system_action),
            "user_target_sp": entry.user_target_sp,
            "user_target": entry.user_target,
            "user_target_suql": entry.user_target_suql,
        }
        for entry in dialogue
    ]
class bcolors:
    """ANSI escape sequences used to style text printed to the terminal."""

    # Foreground colors.
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"

    # Reset sequence: emitted after styled text to restore the defaults.
    ENDC = "\033[0m"

    # Text attributes.
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
def input_user() -> str:
    """Read a non-blank line from the user and return it.

    The styled ``User: `` prompt is shown again until the entered text
    contains something other than whitespace. The reset sequence is printed
    afterwards even if ``input`` raises.

    Returns:
        str: The text exactly as entered (not stripped).
    """
    try:
        while True:
            entered = input(bcolors.OKCYAN + bcolors.BOLD + "User: ")
            # Blank or whitespace-only lines are ignored; ask again.
            if entered.strip():
                break
    finally:
        print(bcolors.ENDC)
    return entered
def print_chatbot(s: str):
    """Print ``s`` as an agent message: OKGREEN + BOLD, prefixed with "Agent: "."""
    styled_prefix = bcolors.OKGREEN + bcolors.BOLD + "Agent: "
    # ENDC resets the terminal styling after the message.
    print(styled_prefix + s + bcolors.ENDC)
def print_user(s: str):
    """Print ``s`` as a user message: OKCYAN + BOLD, prefixed with "User: "."""
    styled_prefix = bcolors.OKCYAN + bcolors.BOLD + "User: "
    # ENDC resets the terminal styling after the message.
    print(styled_prefix + s + bcolors.ENDC)
def print_complete_history(dialogue_history):
    """Replay every turn of the dialogue on the terminal.

    For each turn, the user utterance is printed first, followed by the
    system response.

    Args:
        dialogue_history: Iterable of turns exposing ``user_utterance`` and
            ``system_response`` attributes.
    """
    for exchange in dialogue_history:
        print_user(exchange.user_utterance)
        print_chatbot(exchange.system_response)
async def conversation_loop(bot, output_state_path, quit_commands=None):
    """Run an interactive console conversation with the chatbot.

    Repeatedly reads a user utterance, generates the next turn, and prints
    the complete dialogue history. The loop ends when the user enters one of
    the quit commands or when an exception is raised. In every case the
    dialogue history is written to ``output_state_path`` as JSON on exit.

    Args:
        bot: The chatbot instance. Must expose ``dlg_history`` and
            ``starting_prompt``.
        output_state_path (str): The path to save the dialogue history.
        quit_commands (list[str] | str, optional): Commands that end the
            conversation. A single string is treated as a one-element list.
            Defaults to ``["exit", "exit()"]``.
    """
    if quit_commands is None:
        quit_commands = ["exit", "exit()"]
    elif isinstance(quit_commands, str):
        # A bare string would turn the membership test below into a
        # substring match, so wrap it in a list.
        quit_commands = [quit_commands]

    try:
        while True:
            # Show the opening prompt only while no turn has been recorded.
            if len(bot.dlg_history) == 0:
                print_chatbot(bot.starting_prompt)

            user_utterance = input_user()
            # Membership test: comparing the utterance to the list itself
            # with ``==`` is never true, so the loop could not be exited.
            # Surrounding whitespace is ignored for the comparison only.
            if user_utterance.strip() in quit_commands:
                break

            await generate_next_turn(user_utterance, bot)
            print_complete_history(bot.dlg_history)
    except Exception as e:
        # Top-level boundary: report the failure, then fall through to the
        # ``finally`` clause so the history collected so far is still saved.
        print(e)

        import traceback

        traceback.print_exc()
    finally:
        with open(output_state_path, "w") as f:
            json.dump(convert_to_json(bot.dlg_history), f, indent=4)