Skip to content

Commit 892729a

Browse files
authored
Merge pull request #68 from OpenGenerativeAI/nico/simplify
Adds back mistral support
2 parents 356146a + f167e52 commit 892729a

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

agent/llm.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
def get_client(model_str):
1+
from llama_index.core.llms.function_calling import FunctionCallingLLM
2+
3+
4+
def get_client(model_str: str) -> FunctionCallingLLM:
25
split_result = model_str.split(":")
36
if len(split_result) == 1:
47
# Assume default provider to be openai
@@ -11,6 +14,7 @@ def get_client(model_str):
1114
else:
1215
provider = split_result[0]
1316
model_name = split_result[1]
17+
1418
if provider == "openai":
1519
from llama_index.llms.openai import OpenAI
1620

@@ -19,7 +23,11 @@ def get_client(model_str):
1923
from llama_index.llms.anthropic import Anthropic
2024

2125
return Anthropic(model=model_name)
22-
elif provider == "mixtral" or provider == "groq":
26+
elif provider == "mistral":
27+
from llama_index.llms.mistralai import MistralAI
28+
29+
return MistralAI(model=model_name)
30+
elif provider == "groq":
2331
from llama_index.llms.groq import Groq
2432

2533
return Groq(model=model_name)
@@ -30,10 +38,11 @@ def get_client(model_str):
3038
return Ollama(model=model_name)
3139
elif provider == "bedrock":
3240
from llama_index.llms.bedrock import Bedrock
33-
41+
3442
return Bedrock(model=model_name)
3543
elif provider == "cerebras":
3644
from llama_index.llms.cerebras import Cerebras
3745

3846
return Cerebras(model=model_name)
39-
47+
48+
raise ValueError(f"Provider {provider} not found")

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@ llama-index-llms-groq
2424
llama-index-llms-ollama
2525
llama-index-llms-bedrock
2626
llama-index-llms-cerebras
27+
llama-index-llms-mistralai

0 commit comments

Comments
 (0)