Commit

BREAKING CHANGE: Updated how memories are generated, updated default conversational flow to not include long-term memories (faster and more stable responses), removed ModelConfigs entirely and consolidated everything into Endpoints, updated how streaming, truncate length and max tokens are handled so that we can once again go back to simplified presets and have the max token length on the individual nodes.
SomeOddCodeGuy committed Jul 7, 2024
1 parent de74193 commit 0307b07
Showing 174 changed files with 819 additions and 1,597 deletions.
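
For orientation before the diffs: with ModelConfigs removed, everything the services read now lives in a single Endpoints file. A minimal sketch of what such a file might contain after this commit — the keys are the ones read in the code below; the file name and all values are placeholders, not taken from the repository:

# Hypothetical contents of an Endpoints config file after this commit.
# Only the keys are grounded in the diff; every value is illustrative.
example_endpoint_config = {
    "endpoint": "http://localhost:5001",              # base URL, read as endpoint_file["endpoint"]
    "apiKey": "",                                     # optional, read via endpoint_file.get("apiKey", "")
    "modelNameToSendToAPI": "some-model-name",        # replaces the old model_name constructor argument
    "apiTypeConfigFileName": "OpenAIChatCompletion",  # resolved through get_api_type_config()
    "addGenerationPrompt": False                      # read via config_data.get("addGenerationPrompt", False)
}
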
39 changes: 26 additions & 13 deletions Middleware/llmapis/open_ai_llm_chat_completions_api.py
@@ -8,42 +8,56 @@

from Middleware.models.open_ai_api_presets import OpenAiApiPresets
from Middleware.utilities.config_utils import get_openai_preset_path, get_endpoint_config, \
get_is_chat_complete_add_user_assistant, get_is_chat_complete_add_missing_assistant
get_is_chat_complete_add_user_assistant, get_is_chat_complete_add_missing_assistant, get_config_property_if_exists


class OpenAiLlmChatCompletionsApiService:
"""
A service class that provides compatibility with OpenAI's API for interacting with LLMs.
"""

def __init__(self, endpoint: str, model_name: str, presetname: str, stream: bool = False):
def __init__(self, endpoint: str, presetname: str, api_type_config, truncate_length, max_tokens,
stream: bool = False):
"""
Initializes the OpenAiLlmChatCompletionsApiService with the given configuration.
:param endpoint: The API endpoint URL for the LLM service.
:param model_name: The model name to be used with the LLM service.
:param presetname: The name of the preset file containing API parameters.
:param stream: A boolean indicating whether to use streaming or not.
:param api_type_config: The config file for the specified apiType in the Endpoint
:param truncate_length: The max context length of the model, if it applies
:param max_tokens: The max number of tokens to generate from the response
"""
preset_file = get_openai_preset_path(presetname)
endpoint_file = get_endpoint_config(endpoint)
self.api_key = endpoint_file.get("apiKey", "")
print("Api key found: " + self.api_key)
self.endpoint_url = endpoint_file["endpoint"]
self.model_name = model_name
self.model_name = endpoint_file["modelNameToSendToAPI"]
self.is_busy: bool = False
self.truncate_property_name = get_config_property_if_exists("truncateLengthPropertyName", api_type_config)
self.stream_property_name = get_config_property_if_exists("streamPropertyName", api_type_config)
self.max_token_property_name = get_config_property_if_exists("maxNewTokensPropertyName", api_type_config)

if not os.path.exists(preset_file):
raise FileNotFoundError(f'The preset file {preset_file} does not exist.')

with open(preset_file) as file:
preset = json.load(file)

self._gen_input = OpenAiApiPresets(**preset)
self._gen_input_raw = OpenAiApiPresets(**preset)
self._gen_input = self._gen_input_raw.to_json()
# Add optional fields if they are not None
if self.truncate_property_name:
self._gen_input[self.truncate_property_name] = truncate_length
if self.stream_property_name:
self._gen_input[self.stream_property_name] = stream
if self.max_token_property_name:
self._gen_input[self.max_token_property_name] = max_tokens

self.endpoint: str = endpoint_file["endpoint"]
self.stream: bool = stream
self._api = OpenAiChatCompletionsApi(self.api_key, self.endpoint_url)
self._api = OpenAiChatCompletionsApi(self.api_key, self.endpoint_url, api_type_config)

def get_response_from_llm(self, conversation: List[Dict[str, str]]) -> Union[
Generator[str, None, None], Dict[str, Any], None]:
@@ -66,14 +80,13 @@ def get_response_from_llm(self, conversation: List[Dict[str, str]]) -> Union[
try:
self.is_busy = True

if self.stream:
self._gen_input.stream = True
if self.stream and self.stream_property_name:
return self._api.invoke_streaming(messages=corrected_conversation, endpoint=self.endpoint,
model=self.model_name, params=self._gen_input.to_json())
model=self.model_name, params=self._gen_input)
else:
result = self._api.invoke_non_streaming(messages=corrected_conversation, endpoint=self.endpoint,
model_name=self.model_name,
params=self._gen_input.to_json())
params=self._gen_input)
print("######################################")
print("Non-streaming output: ", result)
print("######################################")
@@ -98,7 +111,7 @@ class OpenAiChatCompletionsApi:
A class that encapsulates the functionality to interact with the OpenAI API.
"""

def __init__(self, api_key: str, endpoint: str):
def __init__(self, api_key: str, endpoint: str, api_type_config):
"""
Initializes the OpenAiChatCompletionsApi with the base URL and default headers.
@@ -129,7 +142,8 @@ def invoke_streaming(self, messages: List[Dict[str, str]], endpoint: str, model:
:return: A generator yielding chunks of the response in SSE format.
"""
url: str = f"{endpoint}/v1/chat/completions"
data: Dict[str, Any] = {"model": model, "stream": True, "messages": messages, **(params or {})}
data: Dict[str, Any] = {"model": model, "messages": messages, **(params or {})}

add_user_assistant = get_is_chat_complete_add_user_assistant()
add_missing_assistant = get_is_chat_complete_add_missing_assistant()
print(f"Streaming flow!")
@@ -148,7 +162,6 @@ def sse_format(data: str) -> str:
first_chunk_processed = False
max_buffer_length = 100
for chunk in r.iter_content(chunk_size=1024, decode_unicode=True):
print("Chunk received: {}".format(chunk))
buffer += chunk
while "data:" in buffer:
data_pos = buffer.find("data:")
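
The new api_type_config passed into the service comes from a file in an ApiTypes folder (see get_api_type_config further down). A hedged sketch of what such a file might look like for an OpenAI-compatible chat backend — the keys and the "type" value are the ones used in this commit; the property-name values ("stream", "max_tokens") are assumptions about a typical backend, not taken from the repository:

# Hypothetical ApiTypes config. An empty or missing property name makes
# get_config_property_if_exists() return None, so that field is simply
# left out of the generation payload for this backend.
example_api_type_config = {
    "type": "openAIChatCompletion",            # selects OpenAiLlmChatCompletionsApiService
    "streamPropertyName": "stream",            # injected as {"stream": <bool>}
    "maxNewTokensPropertyName": "max_tokens",  # injected as {"max_tokens": <int>}
    "truncateLengthPropertyName": ""           # empty -> no truncate field is sent
}
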
34 changes: 26 additions & 8 deletions Middleware/llmapis/open_ai_llm_completions_api.py
@@ -8,7 +8,7 @@

from Middleware.models.open_ai_api_presets import OpenAiApiPresets
from Middleware.utilities.config_utils import get_openai_preset_path, get_endpoint_config, \
get_is_chat_complete_add_user_assistant, get_is_chat_complete_add_missing_assistant
get_is_chat_complete_add_user_assistant, get_is_chat_complete_add_missing_assistant, get_config_property_if_exists
from Middleware.utilities.text_utils import return_brackets_in_string


@@ -17,30 +17,44 @@ class OpenAiLlmCompletionsApiService:
A service class that provides compatibility with OpenAI's API for interacting with LLMs.
"""

def __init__(self, endpoint: str, model_name: str, presetname: str, stream: bool = False):
def __init__(self, endpoint: str, presetname: str, api_type_config, truncate_length, max_tokens,
stream: bool = False):
"""
Initializes the OpenAiLlmCompletionsApiService with the given configuration.
:param endpoint: The API endpoint URL for the LLM service.
:param model_name: The name of the model to use.
:param presetname: The name of the preset file containing API parameters.
:param stream: A boolean indicating whether to use streaming or not.
:param api_type_config: The config file for the specified apiType in the Endpoint
:param truncate_length: The max context length of the model, if it applies
:param max_tokens: The max number of tokens to generate from the response
"""
preset_file = get_openai_preset_path(presetname)
endpoint_file = get_endpoint_config(endpoint)
self.model_name = model_name
self.api_key = endpoint_file.get("apiKey", "")
print("API key found: " + self.api_key)
print("Api key found: " + self.api_key)
self.endpoint_url = endpoint_file["endpoint"]
self.model_name = endpoint_file["modelNameToSendToAPI"]
self.is_busy: bool = False
self.truncate_property_name = get_config_property_if_exists("truncateLengthPropertyName", api_type_config)
self.stream_property_name = get_config_property_if_exists("streamPropertyName", api_type_config)
self.max_token_property_name = get_config_property_if_exists("maxNewTokensPropertyName", api_type_config)

if not os.path.exists(preset_file):
raise FileNotFoundError(f'The preset file {preset_file} does not exist.')

with open(preset_file) as file:
preset = json.load(file)

self._gen_input = OpenAiApiPresets(**preset)
self._gen_input_raw = OpenAiApiPresets(**preset)
self._gen_input = self._gen_input_raw.to_json()
# Add optional fields if they are not None
if self.truncate_property_name:
self._gen_input[self.truncate_property_name] = truncate_length
if self.stream_property_name:
self._gen_input[self.stream_property_name] = stream
if self.max_token_property_name:
self._gen_input[self.max_token_property_name] = max_tokens

self.endpoint: str = endpoint_file["endpoint"]
self.stream: bool = stream
@@ -66,11 +80,11 @@ def get_response_from_llm(self, system_prompt: str, prompt: str) -> Union[
if self.stream:
return self._api.invoke_streaming(prompt=full_prompt, endpoint=self.endpoint,
model_name=self.model_name,
params=self._gen_input.to_json())
params=self._gen_input)
else:
result = self._api.invoke_non_streaming(prompt=full_prompt, endpoint=self.endpoint,
model_name=self.model_name,
params=self._gen_input.to_json())
params=self._gen_input)
print("######################################")
print("Non-streaming output: ", result)
print("######################################")
@@ -231,6 +245,10 @@ def invoke_non_streaming(self, prompt: str, endpoint: str, model_name: str,
for attempt in range(retries):
try:
print(f"Non-Streaming flow! Attempt: {attempt + 1}")
print("Headers: ")
print(json.dumps(self.headers))
print("Data: ")
print(json.dumps(data))
response = self.session.post(url, headers=self.headers, json=data, timeout=14400)
response.raise_for_status() # Raises HTTPError for bad responses

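
Both service classes now assemble the request body the same way: the preset is converted to a dict up front, the optional per-API fields are merged in only when the ApiTypes file names them, and the result is spread into the request. A condensed sketch of that merge, with placeholder values:

# Condensed illustration of the new payload assembly; all values are placeholders.
preset = {"temperature": 0.7, "top_p": 0.9}   # from the preset JSON via OpenAiApiPresets(...).to_json()
gen_input = dict(preset)

truncate_property_name = None                 # e.g. not configured for this apiType
stream_property_name = "stream"
max_token_property_name = "max_tokens"

if truncate_property_name:                    # skipped: property not named in the ApiTypes config
    gen_input[truncate_property_name] = 4096
if stream_property_name:
    gen_input[stream_property_name] = False   # stream
if max_token_property_name:
    gen_input[max_token_property_name] = 400  # max_tokens

data = {"model": "some-model-name", "messages": [], **gen_input}
# -> {"model": "some-model-name", "messages": [], "temperature": 0.7,
#     "top_p": 0.9, "stream": False, "max_tokens": 400}
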
29 changes: 16 additions & 13 deletions Middleware/services/llm_service.py
@@ -2,26 +2,29 @@
from Middleware.llmapis.open_ai_llm_completions_api import OpenAiLlmCompletionsApiService
from Middleware.models.llm_handler import LlmHandler
from Middleware.utilities.config_utils import get_chat_template_name, \
load_config, get_model_config_path
get_endpoint_config, get_api_type_config


class LlmHandlerService:

def __init__(self):
self.llm_handler = None

def initialize_llm_handler(self, config_data, preset, endpoint, stream=False):
llm_type = config_data["type"]
add_generation_prompt = config_data["addGenerationPrompt"]
model_name = config_data["modelNameToSendToAPI"]
def initialize_llm_handler(self, config_data, preset, endpoint, stream, truncate_length, max_tokens):
print("Initialize llm hander config_data: {}".format(config_data))
add_generation_prompt = config_data.get("addGenerationPrompt", False)
api_type_config = get_api_type_config(config_data["apiTypeConfigFileName"])
llm_type = api_type_config["type"]
if llm_type == "openAIV1Completion":
print('Loading v1 Completions endpoint: ' + endpoint)
llm = OpenAiLlmCompletionsApiService(endpoint=endpoint, model_name=model_name, presetname=preset,
stream=stream)
llm = OpenAiLlmCompletionsApiService(endpoint=endpoint, presetname=preset,
stream=stream, api_type_config=api_type_config,
truncate_length=truncate_length, max_tokens=max_tokens)
elif llm_type == "openAIChatCompletion":
print('Loading chat Completions endpoint: ' + endpoint)
llm = OpenAiLlmChatCompletionsApiService(endpoint=endpoint, model_name=model_name, presetname=preset,
stream=stream)
llm = OpenAiLlmChatCompletionsApiService(endpoint=endpoint, presetname=preset,
stream=stream, api_type_config=api_type_config,
truncate_length=truncate_length, max_tokens=max_tokens)
else:
raise ValueError(f"Unsupported LLM type: {llm_type}")

@@ -35,10 +38,10 @@ def initialize_llm_handler(self, config_data, preset, endpoint, stream=False):

return self.llm_handler

def load_model_from_config(self, config_name, preset, stream=False):
def load_model_from_config(self, config_name, preset, stream=False, truncate_length=4096, max_tokens=400):
try:
config_file = get_model_config_path(config_name)
config_data = load_config(config_file)
return self.initialize_llm_handler(config_data, preset, config_name, stream)
print("Loading model from: " + config_name)
config_file = get_endpoint_config(config_name)
return self.initialize_llm_handler(config_file, preset, config_name, stream, truncate_length, max_tokens)
except Exception as e:
print(f"Error loading model from config: ", e)
31 changes: 18 additions & 13 deletions Middleware/utilities/config_utils.py
@@ -26,6 +26,13 @@ def load_config(config_file):
return config_data


def get_config_property_if_exists(config_property, config_data):
if config_data.get(config_property) and config_data.get(config_property) != "":
return config_data.get(config_property)
else:
return None


def get_current_username():
"""
Retrieves the current username from the configuration.
@@ -212,19 +219,6 @@ def get_categories_config():
return load_config(config_path)


def get_model_config_path(config_name):
"""
Retrieves the file path to a model configuration file based on the endpoint name.
:param config_name: The name of the endpoint configuration.
:return: The full path to the model configuration file.
"""
endpoint_config = get_config_path('Endpoints', config_name)
endpoint_data = load_config(endpoint_config)
model_config_path = get_config_path('ModelsConfigs', endpoint_data['modelConfigFileName'])
return model_config_path


def get_openai_preset_path(config_name):
"""
Retrieves the file path to a preset configuration file.
@@ -246,6 +240,17 @@ def get_endpoint_config(endpoint):
return load_config(endpoint_file)


def get_api_type_config(api_type):
"""
Retrieves the API type configuration based on the apiType name.
:param api_type: The name of the ApiTypes configuration file.
:return: The loaded API type configuration data.
"""
api_type_file = get_config_path('ApiTypes', api_type)
return load_config(api_type_file)


def get_template_config_path(template_file_name):
"""
Constructs the file path for a prompt template configuration file.
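
The new helper treats a missing key and an empty string the same way, which is what lets an ApiTypes file opt out of a property by leaving it blank. A small sketch of the behaviour (the dictionary is illustrative):

from Middleware.utilities.config_utils import get_config_property_if_exists

api_type_config = {"streamPropertyName": "stream", "truncateLengthPropertyName": ""}

get_config_property_if_exists("streamPropertyName", api_type_config)          # -> "stream"
get_config_property_if_exists("truncateLengthPropertyName", api_type_config)  # -> None (empty string)
get_config_property_if_exists("maxNewTokensPropertyName", api_type_config)    # -> None (key missing)
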
2 changes: 1 addition & 1 deletion Middleware/utilities/prompt_extraction_utils.py
@@ -59,7 +59,7 @@ def extract_discussion_id(messages: List[Dict[str, str]]) -> Optional[str]:
Returns:
Optional[str]: The extracted numeric discussion ID, or None if not found.
"""
pattern = f'{re.escape(discussion_identifiers["discussion_id_start"])}(\\d+){re.escape(discussion_identifiers["discussion_id_end"])}'
pattern = f'{re.escape(discussion_identifiers["discussion_id_start"])}(.*?){re.escape(discussion_identifiers["discussion_id_end"])}'
for message in messages:
match = re.search(pattern, message['content'])
if match:
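
The discussion-ID pattern was loosened from (\d+) to (.*?), so an ID no longer has to be purely numeric. A quick sketch of the effect, assuming tag markers of the form shown below (the real markers live in discussion_identifiers and are not visible in this diff):

import re

# Assumed markers for illustration only.
start, end = "[DiscussionId]", "[/DiscussionId]"
pattern = f'{re.escape(start)}(.*?){re.escape(end)}'

re.search(pattern, "[DiscussionId]chat-42[/DiscussionId]").group(1)  # -> "chat-42" (matches now)
re.search(pattern, "[DiscussionId]12345[/DiscussionId]").group(1)    # -> "12345" (matched before and after)
# With the old (\d+) pattern, re.search() would have returned None for the first example.
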
4 changes: 2 additions & 2 deletions Middleware/workflows/categorization/prompt_categorizer.py
@@ -65,8 +65,8 @@ def get_prompt_category(self, prompt, stream):
Get the category of the prompt and run the appropriate workflow.
Args:
prompt (str): The input prompt to categorize.
stream (bool): Whether to stream the output. Default is False.
prompt: The input prompt to categorize.
stream: Whether to stream the output. Default is False.
Returns:
str: The result of the workflow execution.
26 changes: 25 additions & 1 deletion Middleware/workflows/managers/workflow_manager.py
@@ -133,7 +133,11 @@ def _process_section(self, config: Dict, messages: List[Dict[str, str]] = None,
else:
self.llm_handler = self.llm_handler_service.load_model_from_config(config["endpointName"],
preset,
stream)
stream,
config.get("maxContextTokenSize",
4096),
config.get("maxResponseSizeInTokens",
400))
if "endpointName" not in config:
self.llm_handler = LlmHandler(None, get_chat_template_name(), 0, 0, True)

@@ -182,6 +186,9 @@ def _process_section(self, config: Dict, messages: List[Dict[str, str]] = None,
if config["type"] == "GetCurrentSummaryFromFile":
print("Getting current summary from File")
return self.handle_get_current_summary_from_file(messages)
if config["type"] == "GetCurrentMemoryFromFile":
print("Getting current memories from File")
return self.handle_get_current_memories_from_file(messages)
if config["type"] == "WriteCurrentSummaryToFileAndReturnIt":
print("Writing current summary to file")
return prompt_processor_service.save_summary_to_file(config,
@@ -297,3 +304,20 @@ def handle_get_current_summary_from_file(self, messages):
return "There is not yet a summary file"

return extract_text_blocks_from_hashed_chunks(current_summary)

def handle_get_current_memories_from_file(self, messages):
"""
Retrieves the current memories from a file based on the discussion ID extracted from the messages.
:param messages: List of message dictionaries.
:return: The current memories extracted from the file or a message indicating the absence of any memories.
"""
discussion_id = extract_discussion_id(messages)
filepath = get_discussion_memory_file_path(discussion_id)

current_memories = read_chunks_with_hashes(filepath)

if current_memories is None or len(current_memories) == 0:
return "There are not yet any memories"

return extract_text_blocks_from_hashed_chunks(current_memories)
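
To tie the workflow changes together, a hedged sketch of how nodes might look after this commit. The keys endpointName, maxContextTokenSize and maxResponseSizeInTokens and the node type GetCurrentMemoryFromFile come from this diff; the other values, and the "Standard" node type, are placeholders:

# Hypothetical workflow nodes; only the keys and the memory node type are
# grounded in this commit, everything else is illustrative.
example_workflow_nodes = [
    {
        "type": "GetCurrentMemoryFromFile"   # routed to the memory-from-file handler above
    },
    {
        "type": "Standard",                  # assumed name for an ordinary LLM node
        "endpointName": "SomeEndpoint",      # Endpoints config used to build the handler
        "maxContextTokenSize": 8192,         # falls back to 4096 when omitted
        "maxResponseSizeInTokens": 400       # falls back to 400 when omitted
    }
]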