Skip to content

Commit

Permalink
Fix get_models() and get_async_models() duplicates bug
Browse files Browse the repository at this point in the history
Closes #667, refs #640
  • Loading branch information
simonw committed Dec 5, 2024
1 parent e78fea1 commit b6be09a
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
6 changes: 4 additions & 2 deletions llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,14 @@ class UnknownModelError(KeyError):

def get_models() -> List[Model]:
"Get all registered models"
return [model for model in get_model_aliases().values()]
models_with_aliases = get_models_with_aliases()
return [mwa.model for mwa in models_with_aliases if mwa.model]


def get_async_models() -> List[AsyncModel]:
"Get all registered async models"
return [model for model in get_async_model_aliases().values()]
models_with_aliases = get_models_with_aliases()
return [mwa.async_model for mwa in models_with_aliases if mwa.async_model]


def get_async_model(name: Optional[str] = None) -> AsyncModel:
Expand Down
4 changes: 4 additions & 0 deletions llm/default_plugins/openai_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def register_models(register):
aliases=("gpt-4-turbo-preview", "4-turbo", "4t"),
)
# o1
# register(
# Chat("o1", can_stream=False, allows_system_prompt=False, vision=True),
# AsyncChat("o1", can_stream=False, allows_system_prompt=False, vision=True),
# )
register(
Chat("o1-preview", can_stream=False, allows_system_prompt=False),
AsyncChat("o1-preview", can_stream=False, allows_system_prompt=False),
Expand Down
3 changes: 3 additions & 0 deletions tests/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,9 @@ def test_get_models():
assert all(isinstance(model, llm.Model) for model in models)
model_ids = [model.model_id for model in models]
assert "gpt-4o-mini" in model_ids
# Ensure no model_ids are duplicated
# https://github.com/simonw/llm/issues/667
assert len(model_ids) == len(set(model_ids))


def test_get_async_models():
Expand Down

0 comments on commit b6be09a

Please sign in to comment.