
Commit

Merge pull request #9 from thehunmonkgroup/chatbot-streaming
Chatbot example
Bam4d authored Dec 13, 2023
2 parents 6bb416c + dabcd7d commit cf7b0e3
Showing 1 changed file with 125 additions and 0 deletions.
125 changes: 125 additions & 0 deletions examples/chatbot_with_streaming.py
@@ -0,0 +1,125 @@
#!/usr/bin/env python

# Simple chatbot example -- run with -h argument to see options.

import os
import sys
import argparse
import logging

from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

MODEL_LIST = [
    "mistral-tiny",
    "mistral-small",
    "mistral-medium",
]
DEFAULT_MODEL = "mistral-small"

LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'

logger = logging.getLogger('chatbot')


class ChatBot:
    def __init__(self, api_key, model, system_message=None):
        self.client = MistralClient(api_key=api_key)
        self.model = model
        self.system_message = system_message

    def opening_instructions(self):
        print("""
To chat: type your message and hit enter
To start a new chat: type /new
To exit: type /exit, /quit, or hit CTRL+C
""")

    def new_chat(self):
        self.messages = []
        if self.system_message:
            self.messages.append(ChatMessage(role="system", content=self.system_message))

    def check_exit(self, content):
        if content.lower().strip() in ["/exit", "/quit"]:
            self.exit()

    def check_new_chat(self, content):
        if content.lower().strip() in ["/new"]:
            print("")
            print("Starting new chat...")
            print("")
            self.new_chat()
            return True
        return False

    def run_inference(self, content):
        self.messages.append(ChatMessage(role="user", content=content))

        assistant_response = ""
        logger.debug(f"Sending messages: {self.messages}")
        for chunk in self.client.chat_stream(model=self.model, messages=self.messages):
            response = chunk.choices[0].delta.content
            if response is not None:
                print(response, end="", flush=True)
                assistant_response += response

        print("", flush=True)

        if assistant_response:
            self.messages.append(ChatMessage(role="assistant", content=assistant_response))
        logger.debug(f"Current messages: {self.messages}")

    def start(self):

        self.opening_instructions()
        self.new_chat()

        while True:
            try:
                print("")
                content = input("YOU: ")
                self.check_exit(content)
                if not self.check_new_chat(content):
                    print("")
                    print("MISTRAL:")
                    print("")
                    self.run_inference(content)

            except KeyboardInterrupt:
                self.exit()

    def exit(self):
        logger.debug("Exiting chatbot")
        sys.exit(0)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="A simple chatbot using the Mistral API")
    parser.add_argument("--api-key", default=os.environ.get("MISTRAL_API_KEY"),
                        help="Mistral API key. Defaults to environment variable MISTRAL_API_KEY")
    parser.add_argument("-m", "--model", choices=MODEL_LIST,
                        default=DEFAULT_MODEL,
                        help="Model for chat inference. Choices are %(choices)s. Defaults to %(default)s")
    parser.add_argument("-s", "--system-message",
                        help="Optional system message to prepend.")
    parser.add_argument("-d", "--debug", action="store_true", help="Enable debug logging")

    args = parser.parse_args()

    if args.debug:
        logger.setLevel(logging.DEBUG)
    else:
        logger.setLevel(logging.INFO)

    formatter = logging.Formatter(LOG_FORMAT)

    ch = logging.StreamHandler()
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    logger.debug(f"Starting chatbot with model: {args.model}")

    bot = ChatBot(args.api_key, args.model, args.system_message)
    bot.start()
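
For readers who want to drive the class directly rather than via the command line, here is a minimal sketch of using ChatBot programmatically. The import path, model name, and system message below are illustrative assumptions, not part of this commit:

# Hypothetical usage sketch -- assumes the file above is importable
# (e.g. saved next to your script as chatbot_with_streaming.py) and that
# MISTRAL_API_KEY is set in the environment.
import os

from chatbot_with_streaming import ChatBot  # assumed import path

bot = ChatBot(
    api_key=os.environ["MISTRAL_API_KEY"],
    model="mistral-small",  # any entry from MODEL_LIST
    system_message="You are a concise assistant.",  # optional, illustrative
)
bot.start()  # interactive loop: /new resets history, /exit or /quit ends it

From the shell, the script's own -h flag lists the same options defined above (--api-key, -m/--model, -s/--system-message, -d/--debug).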
