Skip to content

Commit

Permalink
formatting fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
thehunmonkgroup committed Dec 13, 2023
1 parent e39119a commit c66faa7
Showing 1 changed file with 56 additions and 24 deletions.
80 changes: 56 additions & 24 deletions examples/chatbot_with_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
]
DEFAULT_MODEL = "mistral-small"
DEFAULT_TEMPERATURE = 0.7
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
COMMAND_LIST = [
"/new",
"/help",
Expand Down Expand Up @@ -49,20 +49,23 @@ def completer(text, state):

readline.set_completer(completer)
# Remove all delimiters to ensure completion only at the beginning of the line
readline.set_completer_delims('')
readline.set_completer_delims("")
# Enable tab completion
readline.parse_and_bind('tab: complete')
readline.parse_and_bind("tab: complete")


class ChatBot:
def __init__(self, api_key, model, system_message=None, temperature=DEFAULT_TEMPERATURE):
def __init__(
self, api_key, model, system_message=None, temperature=DEFAULT_TEMPERATURE
):
self.client = MistralClient(api_key=api_key)
self.model = model
self.temperature = temperature
self.system_message = system_message

def opening_instructions(self):
print("""
print(
"""
To chat: type your message and hit enter
To start a new chat: /new
To switch model: /model <model name>
Expand All @@ -71,15 +74,20 @@ def opening_instructions(self):
To see current config: /config
To exit: /exit, /quit, or hit CTRL+C
To see this help: /help
""")
"""
)

def new_chat(self):
print("")
print(f"Starting new chat with model: {self.model}, temperature: {self.temperature}")
print(
f"Starting new chat with model: {self.model}, temperature: {self.temperature}"
)
print("")
self.messages = []
if self.system_message:
self.messages.append(ChatMessage(role="system", content=self.system_message))
self.messages.append(
ChatMessage(role="system", content=self.system_message)
)

def switch_model(self, input):
model = self.get_arguments(input)
Expand Down Expand Up @@ -128,9 +136,13 @@ def run_inference(self, content):
self.messages.append(ChatMessage(role="user", content=content))

assistant_response = ""
logger.debug(f"Running inference with model: {self.model}, temperature: {self.temperature}")
logger.debug(
f"Running inference with model: {self.model}, temperature: {self.temperature}"
)
logger.debug(f"Sending messages: {self.messages}")
for chunk in self.client.chat_stream(model=self.model, temperature=self.temperature, messages=self.messages):
for chunk in self.client.chat_stream(
model=self.model, temperature=self.temperature, messages=self.messages
):
response = chunk.choices[0].delta.content
if response is not None:
print(response, end="", flush=True)
Expand All @@ -139,7 +151,9 @@ def run_inference(self, content):
print("", flush=True)

if assistant_response:
self.messages.append(ChatMessage(role="assistant", content=assistant_response))
self.messages.append(
ChatMessage(role="assistant", content=assistant_response)
)
logger.debug(f"Current messages: {self.messages}")

def get_command(self, input):
Expand Down Expand Up @@ -191,18 +205,34 @@ def exit(self):


if __name__ == "__main__":

parser = argparse.ArgumentParser(description="A simple chatbot using the Mistral API")
parser.add_argument("--api-key", default=os.environ.get("MISTRAL_API_KEY"),
help="Mistral API key. Defaults to environment variable MISTRAL_API_KEY")
parser.add_argument("-m", "--model", choices=MODEL_LIST,
default=DEFAULT_MODEL,
help="Model for chat inference. Choices are %(choices)s. Defaults to %(default)s")
parser.add_argument("-s", "--system-message",
help="Optional system message to prepend.")
parser.add_argument("-t", "--temperature", type=float, default=DEFAULT_TEMPERATURE,
help="Optional temperature for chat inference. Defaults to %(default)s")
parser.add_argument("-d", "--debug", action="store_true", help="Enable debug logging")
parser = argparse.ArgumentParser(
description="A simple chatbot using the Mistral API"
)
parser.add_argument(
"--api-key",
default=os.environ.get("MISTRAL_API_KEY"),
help="Mistral API key. Defaults to environment variable MISTRAL_API_KEY",
)
parser.add_argument(
"-m",
"--model",
choices=MODEL_LIST,
default=DEFAULT_MODEL,
help="Model for chat inference. Choices are %(choices)s. Defaults to %(default)s",
)
parser.add_argument(
"-s", "--system-message", help="Optional system message to prepend."
)
parser.add_argument(
"-t",
"--temperature",
type=float,
default=DEFAULT_TEMPERATURE,
help="Optional temperature for chat inference. Defaults to %(default)s",
)
parser.add_argument(
"-d", "--debug", action="store_true", help="Enable debug logging"
)

args = parser.parse_args()

Expand All @@ -217,7 +247,9 @@ def exit(self):
ch.setFormatter(formatter)
logger.addHandler(ch)

logger.debug(f"Starting chatbot with model: {args.model}, temperature: {args.temperature}, system message: {args.system_message}")
logger.debug(
f"Starting chatbot with model: {args.model}, temperature: {args.temperature}, system message: {args.system_message}"
)

bot = ChatBot(args.api_key, args.model, args.system_message, args.temperature)
bot.start()

0 comments on commit c66faa7

Please sign in to comment.