Skip to content

Commit

Permalink
Merge pull request #11 from thehunmonkgroup/chatbot-streaming
Browse files Browse the repository at this point in the history
Chatbot example enhancements
  • Loading branch information
Bam4d authored Dec 18, 2023
2 parents 91854c5 + b522bf0 commit bcbf938
Showing 1 changed file with 172 additions and 40 deletions.
212 changes: 172 additions & 40 deletions examples/chatbot_with_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import argparse
import logging
import os
import readline
import sys

from mistralai.client import MistralClient
Expand All @@ -16,49 +17,132 @@
"mistral-medium",
]
DEFAULT_MODEL = "mistral-small"
DEFAULT_TEMPERATURE = 0.7
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
COMMAND_LIST = [
"/new",
"/help",
"/model",
"/system",
"/temperature",
"/config",
"/quit",
"/exit",
]

LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"

logger = logging.getLogger("chatbot")


def completer(text, state):
buffer = readline.get_line_buffer()
if not buffer.startswith(text):
return None

options = [command for command in COMMAND_LIST if command.startswith(text)]
if state < len(options):
return options[state]
else:
return None


readline.set_completer(completer)
# Remove all delimiters to ensure completion only at the beginning of the line
readline.set_completer_delims("")
# Enable tab completion
readline.parse_and_bind("tab: complete")


class ChatBot:
def __init__(self, api_key, model, system_message=None):
def __init__(
self, api_key, model, system_message=None, temperature=DEFAULT_TEMPERATURE
):
self.client = MistralClient(api_key=api_key)
self.model = model
self.temperature = temperature
self.system_message = system_message

def opening_instructions(self):
print("""
print(
"""
To chat: type your message and hit enter
To start a new chat: type /new
To exit: type /exit, /quit, or hit CTRL+C
""")
To start a new chat: /new
To switch model: /model <model name>
To switch system message: /system <message>
To switch temperature: /temperature <temperature>
To see current config: /config
To exit: /exit, /quit, or hit CTRL+C
To see this help: /help
"""
)

def new_chat(self):
print("")
print(
f"Starting new chat with model: {self.model}, temperature: {self.temperature}"
)
print("")
self.messages = []
if self.system_message:
self.messages.append(ChatMessage(role="system", content=self.system_message))
self.messages.append(
ChatMessage(role="system", content=self.system_message)
)

def check_exit(self, content):
if content.lower().strip() in ["/exit", "/quit"]:
self.exit()
def switch_model(self, input):
model = self.get_arguments(input)
if model in MODEL_LIST:
self.model = model
logger.info(f"Switching model: {model}")
else:
logger.error(f"Invalid model name: {model}")

def check_new_chat(self, content):
if content.lower().strip() in ["/new"]:
print("")
print("Starting new chat...")
print("")
def switch_system_message(self, input):
system_message = self.get_arguments(input)
if system_message:
self.system_message = system_message
logger.info(f"Switching system message: {system_message}")
self.new_chat()
return True
return False
else:
logger.error(f"Invalid system message: {system_message}")

def switch_temperature(self, input):
temperature = self.get_arguments(input)
try:
temperature = float(temperature)
if temperature < 0 or temperature > 1:
raise ValueError
self.temperature = temperature
logger.info(f"Switching temperature: {temperature}")
except ValueError:
logger.error(f"Invalid temperature: {temperature}")

def show_config(self):
print("")
print(f"Current model: {self.model}")
print(f"Current temperature: {self.temperature}")
print(f"Current system message: {self.system_message}")
print("")

def collect_user_input(self):
print("")
return input("YOU: ")

def run_inference(self, content):
print("")
print("MISTRAL:")
print("")

self.messages.append(ChatMessage(role="user", content=content))

assistant_response = ""
logger.debug(
f"Running inference with model: {self.model}, temperature: {self.temperature}"
)
logger.debug(f"Sending messages: {self.messages}")
for chunk in self.client.chat_stream(model=self.model, messages=self.messages):
for chunk in self.client.chat_stream(
model=self.model, temperature=self.temperature, messages=self.messages
):
response = chunk.choices[0].delta.content
if response is not None:
print(response, end="", flush=True)
Expand All @@ -67,24 +151,50 @@ def run_inference(self, content):
print("", flush=True)

if assistant_response:
self.messages.append(ChatMessage(role="assistant", content=assistant_response))
self.messages.append(
ChatMessage(role="assistant", content=assistant_response)
)
logger.debug(f"Current messages: {self.messages}")

def start(self):
def get_command(self, input):
return input.split()[0].strip()

def get_arguments(self, input):
try:
return " ".join(input.split()[1:])
except IndexError:
return ""

def is_command(self, input):
return self.get_command(input) in COMMAND_LIST

def execute_command(self, input):
command = self.get_command(input)
if command in ["/exit", "/quit"]:
self.exit()
elif command == "/help":
self.opening_instructions()
elif command == "/new":
self.new_chat()
elif command == "/model":
self.switch_model(input)
elif command == "/system":
self.switch_system_message(input)
elif command == "/temperature":
self.switch_temperature(input)
elif command == "/config":
self.show_config()

def start(self):
self.opening_instructions()
self.new_chat()

while True:
try:
print("")
content = input("YOU: ")
self.check_exit(content)
if not self.check_new_chat(content):
print("")
print("MISTRAL:")
print("")
self.run_inference(content)
input = self.collect_user_input()
if self.is_command(input):
self.execute_command(input)
else:
self.run_inference(input)

except KeyboardInterrupt:
self.exit()
Expand All @@ -95,16 +205,34 @@ def exit(self):


if __name__ == "__main__":

parser = argparse.ArgumentParser(description="A simple chatbot using the Mistral API")
parser.add_argument("--api-key", default=os.environ.get("MISTRAL_API_KEY"),
help="Mistral API key. Defaults to environment variable MISTRAL_API_KEY")
parser.add_argument("-m", "--model", choices=MODEL_LIST,
default=DEFAULT_MODEL,
help="Model for chat inference. Choices are %(choices)s. Defaults to %(default)s")
parser.add_argument("-s", "--system-message",
help="Optional system message to prepend.")
parser.add_argument("-d", "--debug", action="store_true", help="Enable debug logging")
parser = argparse.ArgumentParser(
description="A simple chatbot using the Mistral API"
)
parser.add_argument(
"--api-key",
default=os.environ.get("MISTRAL_API_KEY"),
help="Mistral API key. Defaults to environment variable MISTRAL_API_KEY",
)
parser.add_argument(
"-m",
"--model",
choices=MODEL_LIST,
default=DEFAULT_MODEL,
help="Model for chat inference. Choices are %(choices)s. Defaults to %(default)s",
)
parser.add_argument(
"-s", "--system-message", help="Optional system message to prepend."
)
parser.add_argument(
"-t",
"--temperature",
type=float,
default=DEFAULT_TEMPERATURE,
help="Optional temperature for chat inference. Defaults to %(default)s",
)
parser.add_argument(
"-d", "--debug", action="store_true", help="Enable debug logging"
)

args = parser.parse_args()

Expand All @@ -119,7 +247,11 @@ def exit(self):
ch.setFormatter(formatter)
logger.addHandler(ch)

logger.debug(f"Starting chatbot with model: {args.model}")
logger.debug(
f"Starting chatbot with model: {args.model}, "
f"temperature: {args.temperature}, "
f"system message: {args.system_message}"
)

bot = ChatBot(args.api_key, args.model, args.system_message)
bot = ChatBot(args.api_key, args.model, args.system_message, args.temperature)
bot.start()

0 comments on commit bcbf938

Please sign in to comment.