1- def get_client (model_str ):
1+ from llama_index .core .llms .function_calling import FunctionCallingLLM
2+
3+
4+ def get_client (model_str : str ) -> FunctionCallingLLM :
25 split_result = model_str .split (":" )
36 if len (split_result ) == 1 :
47 # Assume default provider to be openai
@@ -11,6 +14,7 @@ def get_client(model_str):
1114 else :
1215 provider = split_result [0 ]
1316 model_name = split_result [1 ]
17+
1418 if provider == "openai" :
1519 from llama_index .llms .openai import OpenAI
1620
@@ -19,7 +23,11 @@ def get_client(model_str):
1923 from llama_index .llms .anthropic import Anthropic
2024
2125 return Anthropic (model = model_name )
22- elif provider == "mixtral" or provider == "groq" :
26+ elif provider == "mistral" :
27+ from llama_index .llms .mistralai import MistralAI
28+
29+ return MistralAI (model = model_name )
30+ elif provider == "groq" :
2331 from llama_index .llms .groq import Groq
2432
2533 return Groq (model = model_name )
@@ -30,10 +38,11 @@ def get_client(model_str):
3038 return Ollama (model = model_name )
3139 elif provider == "bedrock" :
3240 from llama_index .llms .bedrock import Bedrock
33-
41+
3442 return Bedrock (model = model_name )
3543 elif provider == "cerebras" :
3644 from llama_index .llms .cerebras import Cerebras
3745
3846 return Cerebras (model = model_name )
39-
47+
48+ raise ValueError (f"Provider { provider } not found" )