
Commit c06b11d

feat(llm): integrate LLMClient with litellm (#35)
1 parent: 59e08ac


60 files changed (+666, −633 lines)
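Every call site in this diff follows the same pattern: `OpenAIClient` (from `dbally.llm_client.openai_client`) becomes `LiteLLM` (from `dbally.llms.litellm`), the `llm_client=` keyword becomes `llm=`, and `llm_client.text_generation(...)` becomes `llm.generate_text(...)`. Assembled from the added lines below, downstream code now looks roughly like this (a sketch limited to what the hunks in this commit show):

```python
import dbally
from dbally.llms.litellm import LiteLLM

# LiteLLM replaces OpenAIClient; the constructor keeps the model_name argument
# (and, in the benchmark files below, an api_key argument).
llm = LiteLLM(model_name="gpt-3.5-turbo")

# create_collection now takes the client under the llm= keyword instead of llm_client=.
my_collection = dbally.create_collection("collection_name", llm=llm)

# Free-text generation is now llm.generate_text(...) rather than llm_client.text_generation(...);
# see the text2sql benchmark hunk below for the full call.
```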

README.md

Lines changed: 4 additions & 4 deletions

@@ -31,7 +31,7 @@ This is a basic implementation of a db-ally view for an example HR application,
 
 ```python
 from dbally import decorators, SqlAlchemyBaseView, create_collection
-from dbally.llm_client.openai_client import OpenAIClient
+from dbally.llms.litellm import LiteLLM
 from sqlalchemy import create_engine
 
 class CandidateView(SqlAlchemyBaseView):
@@ -53,7 +53,7 @@ class CandidateView(SqlAlchemyBaseView):
         return Candidate.country == country
 
 engine = create_engine('sqlite:///candidates.db')
-llm = OpenAIClient(model_name="gpt-3.5-turbo")
+llm = LiteLLM(model_name="gpt-3.5-turbo")
 my_collection = create_collection("collection_name", llm)
 my_collection.add(CandidateView, lambda: CandidateView(engine))
 
@@ -82,12 +82,12 @@ pip install dbally
 
 Additionally, you can install one of our extensions to use specific features.
 
-* `dbally[openai]`: Use [OpenAI's models](https://platform.openai.com/docs/models)
+* `dbally[litellm]`: Use [100+ LLMs](https://docs.litellm.ai/docs/providers)
 * `dbally[faiss]`: Use [Faiss](https://github.com/facebookresearch/faiss) indexes for similarity search
 * `dbally[langsmith]`: Use [LangSmith](https://www.langchain.com/langsmith) for query tracking
 
 ```bash
-pip install dbally[openai,faiss,langsmith]
+pip install dbally[litellm,faiss,langsmith]
 ```
 
 ## License

benchmark/dbally_benchmark/e2e_benchmark.py

Lines changed: 3 additions & 3 deletions

@@ -23,7 +23,7 @@
 import dbally
 from dbally.collection import Collection
 from dbally.iql_generator.iql_prompt_template import default_iql_template
-from dbally.llm_client.openai_client import OpenAIClient
+from dbally.llms.litellm import LiteLLM
 from dbally.utils.errors import NoViewFoundError, UnsupportedQueryError
 from dbally.view_selection.view_selector_prompt_template import default_view_selector_template
 
@@ -82,12 +82,12 @@ async def evaluate(cfg: DictConfig) -> Any:
 
     engine = create_engine(benchmark_cfg.pg_connection_string + f"/{cfg.db_name}")
 
-    llm_client = OpenAIClient(
+    llm = LiteLLM(
         model_name="gpt-4",
         api_key=benchmark_cfg.openai_api_key,
     )
 
-    db = dbally.create_collection(cfg.db_name, llm_client)
+    db = dbally.create_collection(cfg.db_name, llm)
 
     for view_name in cfg.view_names:
         view = VIEW_REGISTRY[ViewName(view_name)]

benchmark/dbally_benchmark/iql_benchmark.py

Lines changed: 3 additions & 3 deletions

@@ -22,7 +22,7 @@
 from dbally.audit.event_tracker import EventTracker
 from dbally.iql_generator.iql_generator import IQLGenerator
 from dbally.iql_generator.iql_prompt_template import default_iql_template
-from dbally.llm_client.openai_client import OpenAIClient
+from dbally.llms.litellm import LiteLLM
 from dbally.utils.errors import UnsupportedQueryError
 from dbally.views.structured import BaseStructuredView
 
@@ -96,13 +96,13 @@ async def evaluate(cfg: DictConfig) -> Any:
     view = VIEW_REGISTRY[ViewName(view_name)](engine)
 
     if "gpt" in cfg.model_name:
-        llm_client = OpenAIClient(
+        llm = LiteLLM(
             model_name=cfg.model_name,
             api_key=benchmark_cfg.openai_api_key,
         )
     else:
         raise ValueError("Only OpenAI's GPT models are supported for now.")
-    iql_generator = IQLGenerator(llm_client=llm_client)
+    iql_generator = IQLGenerator(llm=llm)
 
     run = None
     if cfg.neptune.log:

benchmark/dbally_benchmark/text2sql_benchmark.py

Lines changed: 8 additions & 11 deletions

@@ -21,8 +21,7 @@
 from sqlalchemy import create_engine
 
 from dbally.audit.event_tracker import EventTracker
-from dbally.llm_client.base import LLMClient
-from dbally.llm_client.openai_client import OpenAIClient
+from dbally.llms.litellm import LiteLLM
 
 
 def _load_db_schema(db_name: str, encoding: Optional[str] = None) -> str:
@@ -35,12 +34,12 @@ def _load_db_schema(db_name: str, encoding: Optional[str] = None) -> str:
     return db_schema
 
 
-async def _run_text2sql_for_single_example(example: BIRDExample, llm_client: LLMClient) -> Text2SQLResult:
+async def _run_text2sql_for_single_example(example: BIRDExample, llm: LiteLLM) -> Text2SQLResult:
     event_tracker = EventTracker()
 
     db_schema = _load_db_schema(example.db_id)
 
-    response = await llm_client.text_generation(
+    response = await llm.generate_text(
         TEXT2SQL_PROMPT_TEMPLATE, {"schema": db_schema, "question": example.question}, event_tracker=event_tracker
     )
 
@@ -49,13 +48,13 @@ async def _run_text2sql_for_single_example(example: BIRDExample, llm_client: LLM
     )
 
 
-async def run_text2sql_for_dataset(dataset: BIRDDataset, llm_client: LLMClient) -> List[Text2SQLResult]:
+async def run_text2sql_for_dataset(dataset: BIRDDataset, llm: LiteLLM) -> List[Text2SQLResult]:
     """
     Transforms questions into SQL queries using a Text2SQL model.
 
     Args:
         dataset: The dataset containing questions to be transformed into SQL queries.
-        llm_client: LLM client.
+        llm: LLM client.
 
     Returns:
         A list of Text2SQLResult objects representing the predictions.
@@ -64,9 +63,7 @@ async def run_text2sql_for_dataset(dataset: BIRDDataset, llm_client: LLMClient)
     results: List[Text2SQLResult] = []
 
     for group in batch(dataset, 5):
-        current_results = await asyncio.gather(
-            *[_run_text2sql_for_single_example(example, llm_client) for example in group]
-        )
+        current_results = await asyncio.gather(*[_run_text2sql_for_single_example(example, llm) for example in group])
         results = [*current_results, *results]
 
     return results
@@ -88,7 +85,7 @@ async def evaluate(cfg: DictConfig) -> Any:
     engine = create_engine(benchmark_cfg.pg_connection_string + f"/{cfg.db_name}")
 
     if "gpt" in cfg.model_name:
-        llm_client = OpenAIClient(
+        llm = LiteLLM(
            model_name=cfg.model_name,
            api_key=benchmark_cfg.openai_api_key,
        )
@@ -112,7 +109,7 @@ async def evaluate(cfg: DictConfig) -> Any:
     evaluation_dataset = BIRDDataset.from_json_file(
         Path(cfg.dataset_path), difficulty_levels=cfg.get("difficulty_levels")
     )
-    text2sql_results = await run_text2sql_for_dataset(dataset=evaluation_dataset, llm_client=llm_client)
+    text2sql_results = await run_text2sql_for_dataset(dataset=evaluation_dataset, llm=llm)
 
     with open(output_dir / results_file_name, "w", encoding="utf-8") as outfile:
         json.dump([result.model_dump() for result in text2sql_results], outfile, indent=4)
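Read together, the last two hunks show how the renamed pieces fit: the benchmark builds a `LiteLLM`, loads a `BIRDDataset`, and hands both to `run_text2sql_for_dataset`. A minimal driver for the same flow might look like the sketch below; the `dbally_benchmark.*` import paths are inferred from the file locations in this commit and are assumptions, as is the dataset path.

```python
import asyncio
from pathlib import Path

from dbally.llms.litellm import LiteLLM

# Assumed module paths, inferred from the benchmark file paths in this commit.
from dbally_benchmark.dataset.bird_dataset import BIRDDataset
from dbally_benchmark.text2sql_benchmark import run_text2sql_for_dataset


async def main() -> None:
    # LiteLLM replaces OpenAIClient; model_name and api_key are passed exactly as in evaluate().
    llm = LiteLLM(model_name="gpt-4", api_key="...")

    # Same loader call as in evaluate(); "dev.json" is a placeholder path.
    dataset = BIRDDataset.from_json_file(Path("dev.json"))

    # The helper now takes the client under llm= instead of llm_client=.
    results = await run_text2sql_for_dataset(dataset=dataset, llm=llm)
    print(f"{len(results)} predictions generated")


asyncio.run(main())
```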

docs/concepts/collections.md

Lines changed: 2 additions & 2 deletions

@@ -3,15 +3,15 @@
 At its core, a collection groups together multiple [views](views.md). Once you've defined your views, the next step is to register them within a collection. Here's how you might do it:
 
 ```python
-my_collection = dbally.create_collection("collection_name", llm_client=OpenAIClient())
+my_collection = dbally.create_collection("collection_name", llm=LiteLLM())
 my_collection.add(ExampleView)
 my_collection.add(RecipesView)
 ```
 
 Sometimes, view classes might need certain arguments when they're instantiated. In these instances, you'll want to register your view with a builder function that takes care of supplying these arguments. For instance, with views that rely on SQLAlchemy, you'll typically need to pass a database engine object like so:
 
 ```python
-my_collection = dbally.create_collection("collection_name", llm_client=OpenAIClient())
+my_collection = dbally.create_collection("collection_name", llm=LiteLLM())
 engine = sqlalchemy.create_engine("sqlite://")
 my_collection.add(ExampleView, lambda: ExampleView(engine))
 my_collection.add(RecipesView, lambda: RecipesView(engine))
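The doc registers views with a lambda as the builder; any zero-argument callable works the same way, and a named builder can read better when a view needs more setup. A small sketch, reusing the `ExampleView(engine)` constructor from the snippet above (the view class itself is defined elsewhere in the docs):

```python
import dbally
import sqlalchemy
from dbally.llms.litellm import LiteLLM

engine = sqlalchemy.create_engine("sqlite://")


def build_example_view() -> "ExampleView":
    # Any zero-argument callable can serve as the builder passed to add().
    return ExampleView(engine)


my_collection = dbally.create_collection("collection_name", llm=LiteLLM())
my_collection.add(ExampleView, build_example_view)
```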

docs/concepts/freeform_views.md

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@
 
 Freeform views are a type of [view](views.md) that provides a way for developers using db-ally to define what they need from the LLM without requiring a fixed response structure. This flexibility is beneficial when the data structure is unknown beforehand or when potential queries are too diverse to be covered by a structured view. Though freeform views offer more flexibility than structured views, they are less predictable, efficient, and secure, and may be more challenging to integrate with other systems. For these reasons, we recommend using [structured views](./structured_views.md) when possible.
 
-Unlike structured views, which define a response format and a set of operations the LLM may use in response to natural language queries, freeform views only have one task - to respond directly to natural language queries with data from the datasource. They accomplish this by implementing the [`ask`][dbally.views.base.BaseView] method. This method takes a natural language query as input and returns a response. The method also has access to the LLM model (via the `llm_client` attribute), which is typically used to retrieve the correct data from the source (for example, by generating a source-specific query string). To learn more about implementing freeform views, refer to the [How to: Custom Freeform Views](../how-to/custom_freeform_views.md) guide.
+Unlike structured views, which define a response format and a set of operations the LLM may use in response to natural language queries, freeform views only have one task - to respond directly to natural language queries with data from the datasource. They accomplish this by implementing the [`ask`][dbally.views.base.BaseView] method. This method takes a natural language query as input and returns a response. The method also has access to the LLM model (via the `llm` attribute), which is typically used to retrieve the correct data from the source (for example, by generating a source-specific query string). To learn more about implementing freeform views, refer to the [How to: Custom Freeform Views](../how-to/custom_freeform_views.md) guide.
 
 ## Security
 
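The changed paragraph describes the shape of a freeform view only in prose. The sketch below illustrates that shape in plain Python; it deliberately does not subclass the real `dbally.views.base.BaseView` (whose exact `ask` signature is not visible in this diff), so every name in it is illustrative only:

```python
from dbally.llms.litellm import LiteLLM


class IllustrativeFreeformView:
    """Illustration only: respond to a natural-language query straight from a datasource.

    A real freeform view subclasses dbally.views.base.BaseView and implements its
    ask() method; see the Custom Freeform Views how-to for the actual interface.
    """

    def __init__(self, llm: LiteLLM, connection) -> None:
        self.llm = llm                # the `llm` attribute the paragraph refers to
        self.connection = connection  # e.g. a sqlite3.Connection

    async def ask(self, query: str):
        # 1. Use the LLM to turn the natural-language query into a source-specific
        #    query string (e.g. SQL); prompt construction is omitted here.
        sql = await self._query_from_natural_language(query)
        # 2. Run it against the datasource and return the rows as the response.
        return self.connection.execute(sql).fetchall()

    async def _query_from_natural_language(self, query: str) -> str:
        raise NotImplementedError("prompting logic goes here")
```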

docs/how-to/create_custom_event_handler.md

Lines changed: 2 additions & 2 deletions

@@ -117,11 +117,11 @@ To use our event handler, we need to pass it to the collection when creating it:
 
 ```python
 import dbally
-from dbally.llm_client.openai_client import OpenAIClient
+from dbally.llms.litellm import LiteLLM
 
 my_collection = bally.create_collection(
     "collection_name",
-    llm_client=OpenAIClient(),
+    llm=LiteLLM(),
     event_handlers=[FileEventHandler()],
 )
 ```

docs/how-to/custom_views.md

Lines changed: 2 additions & 2 deletions

@@ -219,10 +219,10 @@ Finally, we can use the `CandidatesView` just like any other view in db-ally. We
 ```python
 import asyncio
 import dbally
-from dbally.llm_client.openai_client import OpenAIClient
+from dbally.llms.litellm import LiteLLM
 
 async def main():
-    llm = OpenAIClient(model_name="gpt-3.5-turbo")
+    llm = LiteLLM(model_name="gpt-3.5-turbo")
     collection = dbally.create_collection("recruitment", llm)
     collection.add(CandidateView)
 

docs/how-to/custom_views_code.py

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@
 from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
 from dbally.iql import IQLQuery, syntax
 from dbally.data_models.execution_result import ViewExecutionResult
-from dbally.llm_client.openai_client import OpenAIClient
+from dbally.llms.litellm import LiteLLM
 
 @dataclass
 class Candidate:
@@ -99,7 +99,7 @@ def from_country(self, country: str) -> Callable[[Candidate], bool]:
         return lambda x: x.country == country
 
 async def main():
-    llm = OpenAIClient(model_name="gpt-3.5-turbo")
+    llm = LiteLLM(model_name="gpt-3.5-turbo")
     event_handlers = [CLIEventHandler()]
     collection = dbally.create_collection("recruitment", llm, event_handlers=event_handlers)
     collection.add(CandidateView)

docs/how-to/log_runs_to_langsmith.md

Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@ from dbally.audit.event_handlers.langsmith_event_handler import LangSmithEventHa
 
 my_collection = dbally.create_collection(
     "collection_name",
-    llm_client=OpenAIClient(),
+    llm=LiteLLM(),
     event_handlers=[LangSmithEventHandler(api_key="your_api_key")],
 )
 ```
