Skip to content

openai llm as default #588

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged 1 commit on Jul 18, 2024.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 0 additions & 54 deletions backend/src/gemini_llm.py

This file was deleted.

44 changes: 0 additions & 44 deletions backend/src/generate_graphDocuments_from_llm.py

This file was deleted.

48 changes: 0 additions & 48 deletions backend/src/groq_llama3_llm.py

This file was deleted.

49 changes: 35 additions & 14 deletions backend/src/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
from src.shared.constants import MODEL_VERSIONS


def get_llm(model_version: str):
def get_llm(model: str):
"""Retrieve the specified language model based on the model name."""
env_key = "LLM_MODEL_CONFIG_" + model_version
env_key = "LLM_MODEL_CONFIG_" + model
env_value = os.environ.get(env_key)
logging.info("Model: {}".format(env_key))
if "gemini" in model_version:
if "gemini" in model:
credentials, project_id = google.auth.default()
model_name = MODEL_VERSIONS[model_version]
model_name = MODEL_VERSIONS[model]
llm = ChatVertexAI(
model_name=model_name,
convert_system_message_to_human=True,
Expand All @@ -40,15 +40,15 @@ def get_llm(model_version: str):
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
},
)
elif "openai" in model_version:
model_name = MODEL_VERSIONS[model_version]
elif "openai" in model:
model_name = MODEL_VERSIONS[model]
llm = ChatOpenAI(
api_key=os.environ.get("OPENAI_API_KEY"),
model=model_name,
temperature=0,
)

elif "azure" in model_version:
elif "azure" in model:
model_name, api_endpoint, api_key, api_version = env_value.split(",")
llm = AzureChatOpenAI(
api_key=api_key,
Expand All @@ -60,21 +60,21 @@ def get_llm(model_version: str):
timeout=None,
)

elif "anthropic" in model_version:
elif "anthropic" in model:
model_name, api_key = env_value.split(",")
llm = ChatAnthropic(
api_key=api_key, model=model_name, temperature=0, timeout=None
)

elif "fireworks" in model_version:
elif "fireworks" in model:
model_name, api_key = env_value.split(",")
llm = ChatFireworks(api_key=api_key, model=model_name)

elif "groq" in model_version:
elif "groq" in model:
model_name, base_url, api_key = env_value.split(",")
llm = ChatGroq(api_key=api_key, model_name=model_name, temperature=0)

elif "bedrock" in model_version:
elif "bedrock" in model:
model_name, aws_access_key, aws_secret_key, region_name = env_value.split(",")
bedrock_client = boto3.client(
service_name="bedrock-runtime",
Expand All @@ -87,17 +87,27 @@ def get_llm(model_version: str):
client=bedrock_client, model_id=model_name, model_kwargs=dict(temperature=0)
)

elif "ollama" in model_version:
elif "ollama" in model:
model_name, base_url = env_value.split(",")
llm = ChatOllama(base_url=base_url, model=model_name)

else:
elif "diffbot" in model:
model_name = "diffbot"
llm = DiffbotGraphTransformer(
diffbot_api_key=os.environ.get("DIFFBOT_API_KEY"),
extract_types=["entities", "facts"],
)
logging.info(f"Model created - Model Version: {model_version}")

else:
model_name, api_endpoint, api_key = env_value.split(",")
llm = ChatOpenAI(
api_key=api_key,
base_url=api_endpoint,
model=model_name,
temperature=0,
)

logging.info(f"Model created - Model Version: {model}")
return llm, model_name


Expand Down Expand Up @@ -162,8 +172,19 @@ def get_graph_document_list(


def get_graph_from_llm(model, chunkId_chunkDoc_list, allowedNodes, allowedRelationship):

llm, model_name = get_llm(model)
combined_chunk_document_list = get_combined_chunks(chunkId_chunkDoc_list)

if allowedNodes is None or allowedNodes=="":
allowedNodes =[]
else:
allowedNodes = allowedNodes.split(',')
if allowedRelationship is None or allowedRelationship=="":
allowedRelationship=[]
else:
allowedRelationship = allowedRelationship.split(',')

graph_document_list = get_graph_document_list(
llm, combined_chunk_document_list, allowedNodes, allowedRelationship
)
Expand Down
4 changes: 2 additions & 2 deletions backend/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from src.graphDB_dataAccess import graphDBdataAccess
from src.document_sources.local_file import get_documents_from_file_by_path
from src.entities.source_node import sourceNode
from src.generate_graphDocuments_from_llm import generate_graphDocuments
from src.llm import get_graph_from_llm
from src.document_sources.gcs_bucket import *
from src.document_sources.s3_bucket import *
from src.document_sources.wikipedia import *
Expand Down Expand Up @@ -373,7 +373,7 @@ def processing_chunks(chunkId_chunkDoc_list,graph,uri, userName, password, datab

update_embedding_create_vector_index( graph, chunkId_chunkDoc_list, file_name)
logging.info("Get graph document list from models")
graph_documents = generate_graphDocuments(model, graph, chunkId_chunkDoc_list, allowedNodes, allowedRelationship)
graph_documents = get_graph_from_llm(model, chunkId_chunkDoc_list, allowedNodes, allowedRelationship)
cleaned_graph_documents = handle_backticks_nodes_relationship_id_type(graph_documents)
save_graphDocuments_in_neo4j(graph, cleaned_graph_documents)
chunks_and_graphDocuments_list = get_chunk_and_graphDocument(cleaned_graph_documents, chunkId_chunkDoc_list)
Expand Down
24 changes: 0 additions & 24 deletions backend/src/openAI_llm.py

This file was deleted.