Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
361 changes: 361 additions & 0 deletions docs/source/notebooks/models.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions docs/source/toc.segment
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
:caption: Introduction

./notebooks/getting_started
./notebooks/models
./notebooks/datasets
```

Expand Down
8 changes: 7 additions & 1 deletion langchain_benchmarks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from langchain_benchmarks.model_registration import model_registry
from langchain_benchmarks.registration import registry
from langchain_benchmarks.utils._langsmith import (
clone_public_dataset,
download_public_dataset,
)

# Please keep this list sorted!
__all__ = ["clone_public_dataset", "download_public_dataset", "registry"]
__all__ = [
"clone_public_dataset",
"download_public_dataset",
"model_registry",
"registry",
]
152 changes: 152 additions & 0 deletions langchain_benchmarks/model_registration.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No gpt-4 or gpt-4-1106-preview? What about claude?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah we can add all of those -- just wanted to get something initial in for a review / consensus that we want this

Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from __future__ import annotations

from langchain_benchmarks.schema import ModelRegistry, RegisteredModel

_OpenAIModels = [
RegisteredModel(
provider="openai",
name="gpt-3.5-turbo-1106",
type="chat",
description=(
"The latest GPT-3.5 Turbo model with improved instruction following, "
"JSON mode, reproducible outputs, parallel function calling, and more. "
"Returns a maximum of 4,096 output tokens."
),
params={
"model": "gpt-3.5-turbo-1106",
},
),
RegisteredModel(
provider="openai",
name="gpt-3.5-turbo",
type="chat",
description="Currently points to gpt-3.5-turbo-0613.",
params={
"model": "gpt-3.5-turbo",
},
),
RegisteredModel(
provider="openai",
name="gpt-3.5-turbo-16k",
type="chat",
description="Currently points to gpt-3.5-turbo-0613.",
params={
"model": "gpt-3.5-turbo-16k",
},
),
RegisteredModel(
provider="openai",
name="gpt-3.5-turbo-instruct",
type="llm",
description=(
"Similar capabilities as text-davinci-003 but compatible with legacy "
"Completions endpoint and not Chat Completions."
),
params={
"model": "gpt-3.5-turbo-instruct",
},
),
RegisteredModel(
provider="openai",
name="gpt-3.5-turbo-0613",
type="chat",
description=(
"Legacy Snapshot of gpt-3.5-turbo from June 13th 2023. "
"Will be deprecated on June 13, 2024."
),
params={
"model": "gpt-3.5-turbo-0613",
},
),
RegisteredModel(
provider="openai",
name="gpt-3.5-turbo-16k-0613",
type="chat",
description=(
"Legacy Snapshot of gpt-3.5-16k-turbo from June 13th 2023. "
"Will be deprecated on June 13, 2024."
),
params={
"model": "gpt-3.5-turbo-16k-0613",
},
),
RegisteredModel(
provider="openai",
name="gpt-3.5-turbo-0301",
type="chat",
description=(
"Legacy Snapshot of gpt-3.5-turbo from March 1st 2023. "
"Will be deprecated on June 13th 2024."
),
params={
"model": "gpt-3.5-turbo-0301",
},
),
RegisteredModel(
provider="openai",
name="text-davinci-003",
type="llm",
description=(
"Legacy Can do language tasks with better quality and consistency than "
"the curie, babbage, or ada models. Will be deprecated on Jan 4th 2024."
),
params={
"model": "text-davinci-003",
},
),
RegisteredModel(
provider="openai",
name="text-davinci-002",
type="llm",
description=(
"Legacy Similar capabilities to text-davinci-003 but trained with "
"supervised fine-tuning instead of reinforcement learning. "
"Will be deprecated on Jan 4th 2024."
),
params={
"model": "text-davinci-002",
},
),
RegisteredModel(
provider="openai",
name="code-davinci-002",
type="llm",
description="Legacy Optimized for code-completion tasks. Will be deprecated "
"on Jan 4th 2024.",
params={
"model": "code-davinci-002",
},
),
]

_FireworksModels = [
RegisteredModel(
provider="fireworks",
name="llama-v2-7b-chat-fw",
type="chat",
description="7b parameter LlamaChat model",
params={
"model": "accounts/fireworks/models/llama-v2-7b-chat",
},
),
RegisteredModel(
provider="fireworks",
name="llama-v2-13b-chat-fw",
type="chat",
description="13b parameter LlamaChat model",
params={
"model": "accounts/fireworks/models/llama-v2-13b-chat",
},
),
RegisteredModel(
provider="fireworks",
name="llama-v2-70b-chat-fw",
type="chat",
description="70b parameter LlamaChat model",
params={
"model": "accounts/fireworks/models/llama-v2-70b-chat",
},
),
]

model_registry = ModelRegistry(registered_models=_OpenAIModels + _FireworksModels)
193 changes: 192 additions & 1 deletion langchain_benchmarks/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@
from __future__ import annotations

import dataclasses
import importlib
import urllib
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Type, Union

from langchain.prompts import ChatPromptTemplate
from langchain.schema import BaseRetriever
from langchain.schema.document import Document
from langchain.schema.embeddings import Embeddings
from langchain.tools import BaseTool
from langchain_core.language_models import BaseChatModel, BaseLanguageModel
from pydantic import BaseModel
from tabulate import tabulate
from typing_extensions import Literal


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -153,6 +156,7 @@ def __post_init__(self) -> None:
raise ValueError(
f"Duplicate task name {task.name}. " f"Task names must be unique."
)
seen_names.add(task.name)

def _repr_html_(self) -> str:
"""Return an HTML representation of the registry."""
Expand Down Expand Up @@ -210,3 +214,190 @@ def add(self, task: BaseTask) -> None:
if not isinstance(task, BaseTask):
raise TypeError("Only tasks can be added to the registry.")
self.tasks.append(task)


Provider = Literal["fireworks", "openai"]
ModelType = Literal["chat", "llm"]
AUTHORIZED_NAMESPACES = {"langchain"}


def _get_model_class_from_path(
path: str
) -> Union[Type[BaseChatModel], Type[BaseLanguageModel]]:
"""Get the class of the model."""
module_name, attribute_name = path.rsplit(".", 1)
top_namespace = path.split(".")[0]

if top_namespace not in AUTHORIZED_NAMESPACES:
raise ValueError(
f"Unauthorized namespace {top_namespace}. "
f"Authorized namespaces are: {AUTHORIZED_NAMESPACES}"
)

# Import the module dynamically
module = importlib.import_module(module_name)
model_class = getattr(module, attribute_name)
if not issubclass(model_class, (BaseLanguageModel, BaseChatModel)):
raise ValueError(
f"Model class {model_class} is not a subclass of BaseLanguageModel"
)
return model_class


def _get_default_path(provider: str, type_: ModelType) -> str:
"""Get the default path for a model."""
paths = {
("fireworks", "chat"): "langchain.chat_models.fireworks.ChatFireworks",
("fireworks", "llm"): "langchain.language_models.fireworks.Fireworks",
("openai", "chat"): "langchain.chat_models.openai.ChatOpenAI",
("openai", "llm"): "langchain.language_models.openai.OpenAI",
}

if (provider, type_) not in paths:
raise ValueError(f"Unknown provider {provider} and type {type_}")

return paths[(provider, type_)]


@dataclasses.dataclass(frozen=True)
class RegisteredModel:
"""Descriptive information about a model.

This information can be used to instantiate the underlying model.
"""

name: str
provider: Provider
description: str
params: Dict[str, Any]
type: ModelType
# Path to the model class.
# For example, "langchain.chat_models.anthropic import ChatAnthropicModel"
path: Optional[str] = None # If not provided, will use default path

def get_model(
self, *, model_params: Optional[Dict[str, Any]] = None
) -> Union[BaseChatModel, BaseLanguageModel]:
"""Get the class of the model."""
all_params = {**self.params, **(model_params or {})}
model_class = _get_model_class_from_path(self.model_path)
return model_class(**all_params)

@property
def model_path(self) -> str:
"""Get the path of the model."""
return self.path or _get_default_path(self.provider, self.type)

@property
def _table(self) -> List[List[str]]:
"""Return a table representation of the environment."""
return [
["name", self.name],
["type", self.type],
["provider", self.provider],
["description", self.description],
["model_path", self.model_path],
]

def _repr_html_(self) -> str:
"""Return an HTML representation of the environment."""
return tabulate(
self._table,
tablefmt="unsafehtml",
)


StrFilter = Union[None, str, Sequence[str]]


def _is_in_filter(actual_value: str, filter_value: StrFilter) -> bool:
"""Filter for a string attribute."""
if filter_value is None:
return True

if isinstance(filter_value, str):
return actual_value == filter_value

return actual_value in filter_value


@dataclasses.dataclass(frozen=False)
class ModelRegistry:
registered_models: Sequence[RegisteredModel]

def __post_init__(self) -> None:
"""Validate that all the tasks have unique names and IDs."""
seen_names = set()
for model in self.registered_models:
if model.name in seen_names:
raise ValueError(
f"Duplicate model name {model.name}. " f"Task names must be unique."
)
seen_names.add(model.name)

def get_model(self, name: str) -> Optional[RegisteredModel]:
"""Get model info."""
return next(model for model in self.registered_models if model.name == name)

def filter(
self,
*,
type: StrFilter = None,
name: StrFilter = None,
provider: StrFilter = None,
) -> ModelRegistry:
"""Filter the tasks in the registry."""
models = self.registered_models
selected_models = []

for model in models:
if not _is_in_filter(model.type, type):
continue
if not _is_in_filter(model.name, name):
continue
if not _is_in_filter(model.provider, provider):
continue
selected_models.append(model)
return ModelRegistry(registered_models=selected_models)

def _repr_html_(self) -> str:
"""Return an HTML representation of the registry."""
headers = [
"Name",
"Type",
"Provider",
"Description",
]
table = [
[
model.name,
model.type,
model.provider,
model.description,
]
for model in self.registered_models
]
return tabulate(table, headers=headers, tablefmt="unsafehtml")

def __len__(self) -> int:
"""Return the number of tasks in the registry."""
return len(self.registered_models)

def __iter__(self) -> Iterable[RegisteredModel]:
"""Iterate over the tasks in the registry."""
return iter(self.registered_models)

def __getitem__(
self, key: Union[int, str]
) -> Union[RegisteredModel, ModelRegistry]:
"""Get an environment from the registry."""
if isinstance(key, slice):
return ModelRegistry(registered_models=self.registered_models[key])
elif isinstance(key, (int, str)):
# If key is an integer, return the corresponding environment
if isinstance(key, str):
return self.get_model(key)
else:
return self.registered_models[key]
else:
raise TypeError("Key must be an integer or a slice.")
Loading