feat: store handle in configs #2299

Merged · 3 commits · Dec 21, 2024
Changes from 1 commit
23 changes: 20 additions & 3 deletions letta/providers.py
@@ -27,6 +27,10 @@ def get_model_context_window(self, model_name: str) -> Optional[int]:
def provider_tag(self) -> str:
"""String representation of the provider for display purposes"""
raise NotImplementedError

def get_handle(self, model_name: str) -> str:
return f"{self.name}/{model_name}"



class LettaProvider(Provider):
@@ -40,6 +44,7 @@ def list_llm_models(self) -> List[LLMConfig]:
model_endpoint_type="openai",
model_endpoint="https://inference.memgpt.ai",
context_window=16384,
handle=self.get_handle("letta-free")
)
]

@@ -51,6 +56,7 @@ def list_embedding_models(self):
embedding_endpoint="https://embeddings.memgpt.ai",
embedding_dim=1024,
embedding_chunk_size=300,
handle=self.get_handle("letta-free")
)
]

@@ -115,7 +121,7 @@ def list_llm_models(self) -> List[LLMConfig]:
# continue

configs.append(
LLMConfig(model=model_name, model_endpoint_type="openai", model_endpoint=self.base_url, context_window=context_window_size)
LLMConfig(model=model_name, model_endpoint_type="openai", model_endpoint=self.base_url, context_window=context_window_size, handle=self.get_handle(model_name))
)

# for OpenAI, sort in reverse order
@@ -135,6 +141,7 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:
embedding_endpoint="https://api.openai.com/v1",
embedding_dim=1536,
embedding_chunk_size=300,
handle=self.get_handle("text-embedding-ada-002")
)
]

@@ -163,6 +170,7 @@ def list_llm_models(self) -> List[LLMConfig]:
model_endpoint_type="anthropic",
model_endpoint=self.base_url,
context_window=model["context_window"],
handle=self.get_handle(model["name"])
)
)
return configs
@@ -195,6 +203,7 @@ def list_llm_models(self) -> List[LLMConfig]:
model_endpoint_type="openai",
model_endpoint=self.base_url,
context_window=model["max_context_length"],
handle=self.get_handle(model["id"])
)
)

@@ -250,6 +259,7 @@ def list_llm_models(self) -> List[LLMConfig]:
model_endpoint=self.base_url,
model_wrapper=self.default_prompt_formatter,
context_window=context_window,
handle=self.get_handle(model["name"])
)
)
return configs
@@ -325,6 +335,7 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:
embedding_endpoint=self.base_url,
embedding_dim=embedding_dim,
embedding_chunk_size=300,
handle=self.get_handle(model["name"])
)
)
return configs
@@ -345,7 +356,7 @@ def list_llm_models(self) -> List[LLMConfig]:
continue
configs.append(
LLMConfig(
model=model["id"], model_endpoint_type="groq", model_endpoint=self.base_url, context_window=model["context_window"]
model=model["id"], model_endpoint_type="groq", model_endpoint=self.base_url, context_window=model["context_window"], handle=self.get_handle(model["id"])
)
)
return configs
@@ -413,6 +424,7 @@ def list_llm_models(self) -> List[LLMConfig]:
model_endpoint=self.base_url,
model_wrapper=self.default_prompt_formatter,
context_window=context_window_size,
handle=self.get_handle(model_name)
)
)

@@ -493,6 +505,7 @@ def list_llm_models(self):
model_endpoint_type="google_ai",
model_endpoint=self.base_url,
context_window=self.get_model_context_window(model),
handle=self.get_handle(model)
)
)
return configs
@@ -516,6 +529,7 @@ def list_embedding_models(self):
embedding_endpoint=self.base_url,
embedding_dim=768,
embedding_chunk_size=300, # NOTE: max is 2048
handle=self.get_handle(model)
)
)
return configs
@@ -556,7 +570,7 @@ def list_llm_models(self) -> List[LLMConfig]:
context_window_size = self.get_model_context_window(model_name)
model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version)
configs.append(
LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size)
LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size), handle=self.get_handle(model_name)
)
return configs

@@ -577,6 +591,7 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:
embedding_endpoint=model_endpoint,
embedding_dim=768,
embedding_chunk_size=300, # NOTE: max is 2048
handle=self.get_handle(model_name)
)
)
return configs
@@ -610,6 +625,7 @@ def list_llm_models(self) -> List[LLMConfig]:
model_endpoint_type="openai",
model_endpoint=self.base_url,
context_window=model["max_model_len"],
handle=self.get_handle(model["id"])
)
)
return configs
@@ -642,6 +658,7 @@ def list_llm_models(self) -> List[LLMConfig]:
model_endpoint=self.base_url,
model_wrapper=self.default_prompt_formatter,
context_window=model["max_model_len"],
handle=self.get_handle(model["id"])
)
)
return configs
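Every provider applies the same pattern: each config returned by list_llm_models / list_embedding_models is now stamped with a handle built by the new Provider.get_handle, i.e. "<provider name>/<model name>". A minimal sketch of that pattern, assuming a hypothetical provider (the class and model names below are illustrative stand-ins, not part of letta):

# Minimal sketch of the handle format introduced in this PR.
# "ExampleProvider" and the model names are illustrative stand-ins.
class ExampleProvider:
    name: str = "example"  # the provider name becomes the handle prefix

    def get_handle(self, model_name: str) -> str:
        # Mirrors the new Provider.get_handle: "<provider>/<model>"
        return f"{self.name}/{model_name}"

provider = ExampleProvider()
print(provider.get_handle("letta-free"))   # -> "example/letta-free"
print(provider.get_handle("gpt-4o-mini"))  # -> "example/gpt-4o-mini"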
1 change: 1 addition & 0 deletions letta/schemas/embedding_config.py
@@ -43,6 +43,7 @@ class EmbeddingConfig(BaseModel):
embedding_model: str = Field(..., description="The model for the embedding.")
embedding_dim: int = Field(..., description="The dimension of the embedding.")
embedding_chunk_size: Optional[int] = Field(300, description="The chunk size of the embedding.")
handle: str = Field(..., description="The handle for this config, in the format provider/model-name.")
Collaborator comment:
maybe we should make this optional? there are a lot of tests which directly create an LLMConfig


# azure only
azure_endpoint: Optional[str] = Field(None, description="The Azure endpoint for the model.")
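As the review comment points out, declaring handle with Field(...) makes it a required field, so every test that constructs an EmbeddingConfig or LLMConfig directly would now have to pass it. A sketch of the optional variant the comment suggests, using a trimmed-down stand-in model (not part of this commit):

from typing import Optional
from pydantic import BaseModel, Field

class EmbeddingConfigSketch(BaseModel):
    # Stand-in with only the fields relevant to the discussion.
    embedding_model: str = Field(..., description="The model for the embedding.")
    # Optional variant suggested in the review: omitting it still validates,
    # while providers can keep populating it via get_handle.
    handle: Optional[str] = Field(None, description="The handle for this config, in the format provider/model-name.")

cfg = EmbeddingConfigSketch(embedding_model="text-embedding-ada-002")
assert cfg.handle is None  # constructing without a handle succeeds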
1 change: 1 addition & 0 deletions letta/schemas/llm_config.py
@@ -44,6 +44,7 @@ class LLMConfig(BaseModel):
True,
description="Puts 'inner_thoughts' as a kwarg in the function call if this is set to True. This helps with function calling performance and also the generation of inner thoughts.",
)
handle: str = Field(..., description="The handle for this config, in the format provider/model-name.")

# FIXME hack to silence pydantic protected namespace warning
model_config = ConfigDict(protected_namespaces=())
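With the field as written in this commit, constructing an LLMConfig directly requires supplying the handle explicitly, matching what the providers now produce via get_handle. A hedged usage sketch (the endpoint, context window, and handle values are illustrative):

from letta.schemas.llm_config import LLMConfig

# Illustrative values; the handle mirrors what a provider named "openai"
# would produce for the model "gpt-4o-mini" via get_handle.
config = LLMConfig(
    model="gpt-4o-mini",
    model_endpoint_type="openai",
    model_endpoint="https://api.openai.com/v1",
    context_window=128000,
    handle="openai/gpt-4o-mini",
)
print(config.handle)  # -> "openai/gpt-4o-mini"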