Skip to content

Commit a81fb4e

Browse files
committed
Update factory
1 parent 03c7024 commit a81fb4e

File tree

3 files changed

+21
-6
lines changed

3 files changed

+21
-6
lines changed

langchain_benchmarks/parrot.png

18.5 KB
Loading

langchain_benchmarks/schema.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,15 @@ class RetrievalTask(BaseTask):
143143
"""A function that returns the documents to be indexed."""
144144
retriever_factories: Dict[
145145
str, Callable[[Embeddings], BaseRetriever]
146-
] = dataclasses.field(default_factory=dict) # noqa: F821
146+
] = dataclasses.field(
147+
default_factory=dict
148+
) # noqa: F821
147149
"""Factories that index the docs using the specified strategy."""
148150
architecture_factories: Dict[
149151
str, Callable[[Embeddings], BaseRetriever]
150-
] = dataclasses.field(default_factory=dict) # noqa: F821
152+
] = dataclasses.field(
153+
default_factory=dict
154+
) # noqa: F821
151155
"""Factories methods that help build some off-the-shelf architectures。"""
152156

153157
@property
@@ -251,7 +255,7 @@ def add(self, task: BaseTask) -> None:
251255
self.tasks.append(task)
252256

253257

254-
Provider = Literal["fireworks", "openai", "anthropic"]
258+
Provider = Literal["fireworks", "openai", "anthropic", "anyscale"]
255259
ModelType = Literal["chat", "llm"]
256260
AUTHORIZED_NAMESPACES = {"langchain"}
257261

@@ -284,6 +288,8 @@ def _get_default_path(provider: str, type_: ModelType) -> str:
284288
paths = {
285289
("fireworks", "chat"): "langchain.chat_models.fireworks.ChatFireworks",
286290
("fireworks", "llm"): "langchain.llms.fireworks.Fireworks",
291+
("anyscale", "chat"): "langchain.chat_models.anyscale.ChatAnyscale",
292+
("anyscale", "llm"): "langchain.llms.anyscale.Anyscale",
287293
("openai", "chat"): "langchain.chat_models.openai.ChatOpenAI",
288294
("openai", "llm"): "langchain.llms.openai.OpenAI",
289295
("anthropic", "chat"): "langchain.chat_models.anthropic.ChatAnthropic",
@@ -303,6 +309,8 @@ def _get_default_url(provider: str, type_: ModelType) -> Optional[str]:
303309
return "https://platform.openai.com/docs/models"
304310
elif provider == "anthropic":
305311
return "https://docs.anthropic.com/claude/reference/selecting-a-model"
312+
elif provider == "anyscale":
313+
return "https://docs.endpoints.anyscale.com/category/supported-models"
306314
else:
307315
return None
308316

langchain_benchmarks/tool_usage/agents/openai_functions.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,16 @@ def __init__(
9191

9292
def _create_model(self) -> Union[BaseChatModel, BaseLanguageModel]:
9393
if isinstance(self.model, RegisteredModel):
94-
return self.model.get_model(model_params={"temperature": 0, "seed": 0})
94+
return self.model.get_model(
95+
model_params={"temperature": 0, "model_kwargs": {"seed": 0}}
96+
)
9597
else:
96-
return ChatOpenAI(self.model, temperature=0, model_kwargs={"seed": 0})
98+
return ChatOpenAI(model=self.model, temperature=0, model_kwargs={"seed": 0})
99+
100+
def create(self) -> Runnable:
101+
"""Agent Executor"""
102+
# For backwards compatibility
103+
return self()
97104

98105
def __call__(self) -> Runnable:
99106
model = self._create_model()
@@ -102,7 +109,7 @@ def __call__(self) -> Runnable:
102109

103110
model = _bind_tools(model, env.tools)
104111

105-
if rate_limiting:
112+
if self.rate_limiter is not None:
106113
# Rate limited model
107114
model = rate_limiting.with_rate_limit(model, self.rate_limiter)
108115

0 commit comments

Comments
 (0)