
Commit f6839a8

Add integration for MLflow AI Gateway (#7113)

- Adds integration for MLflow AI Gateway (this will be shipped in MLflow 2.5 this week).

Manual testing:

```sh
# Move to mlflow repo
cd /path/to/mlflow

# install langchain
pip install git+https://github.com/harupy/langchain.git@gateway-integration

# launch gateway service
mlflow gateway start --config-path examples/gateway/openai/config.yaml

# Then, run the examples in this PR
```
1 parent 6792a35 commit f6839a8

File tree

6 files changed (+259 additions, -6 deletions)

docs/extras/ecosystem/integrations/databricks.md

Lines changed: 12 additions & 6 deletions
```diff
@@ -6,22 +6,28 @@
 Databricks embraces the LangChain ecosystem in various ways:

 1. Databricks connector for the SQLDatabase Chain: SQLDatabase.from_databricks() provides an easy way to query your data on Databricks through LangChain
-2. Databricks-managed MLflow integrates with LangChain: Tracking and serving LangChain applications with fewer steps
-3. Databricks as an LLM provider: Deploy your fine-tuned LLMs on Databricks via serving endpoints or cluster driver proxy apps, and query it as langchain.llms.Databricks
-4. Databricks Dolly: Databricks open-sourced Dolly which allows for commercial use, and can be accessed through the Hugging Face Hub
+2. Databricks MLflow integrates with LangChain: Tracking and serving LangChain applications with fewer steps
+3. Databricks MLflow AI Gateway
+4. Databricks as an LLM provider: Deploy your fine-tuned LLMs on Databricks via serving endpoints or cluster driver proxy apps, and query it as langchain.llms.Databricks
+5. Databricks Dolly: Databricks open-sourced Dolly which allows for commercial use, and can be accessed through the Hugging Face Hub

 Databricks connector for the SQLDatabase Chain
 ----------------------------------------------
 You can connect to [Databricks runtimes](https://docs.databricks.com/runtime/index.html) and [Databricks SQL](https://www.databricks.com/product/databricks-sql) using the SQLDatabase wrapper of LangChain. See the notebook [Connect to Databricks](/docs/ecosystem/integrations/databricks/databricks.html) for details.

-Databricks-managed MLflow integrates with LangChain
----------------------------------------------------
+Databricks MLflow integrates with LangChain
+-------------------------------------------

 MLflow is an open source platform to manage the ML lifecycle, including experimentation, reproducibility, deployment, and a central model registry. See the notebook [MLflow Callback Handler](/docs/ecosystem/integrations/mlflow_tracking.ipynb) for details about MLflow's integration with LangChain.

 Databricks provides a fully managed and hosted version of MLflow integrated with enterprise security features, high availability, and other Databricks workspace features such as experiment and run management and notebook revision capture. MLflow on Databricks offers an integrated experience for tracking and securing machine learning model training runs and running machine learning projects. See [MLflow guide](https://docs.databricks.com/mlflow/index.html) for more details.

-Databricks-managed MLflow makes it more convenient to develop LangChain applications on Databricks. For MLflow tracking, you don't need to set the tracking uri. For MLflow Model Serving, you can save LangChain Chains in the MLflow langchain flavor, and then register and serve the Chain with a few clicks on Databricks, with credentials securely managed by MLflow Model Serving.
+Databricks MLflow makes it more convenient to develop LangChain applications on Databricks. For MLflow tracking, you don't need to set the tracking uri. For MLflow Model Serving, you can save LangChain Chains in the MLflow langchain flavor, and then register and serve the Chain with a few clicks on Databricks, with credentials securely managed by MLflow Model Serving.
+
+Databricks MLflow AI Gateway
+----------------------------
+
+See [MLflow AI Gateway](/docs/ecosystem/integrations/mlflow_ai_gateway).

 Databricks as an LLM provider
 -----------------------------
```

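One of the list items above points at `langchain.llms.Databricks` for querying LLMs behind Databricks serving endpoints. A minimal sketch of that usage, not part of this commit: the endpoint name is a placeholder, and workspace host/token resolution is assumed to come from the Databricks runtime or environment configuration.

```python
from langchain.llms import Databricks

# Query a model served behind a Databricks serving endpoint.
# "my-llm-endpoint" is hypothetical; outside a Databricks notebook you would
# also need to configure the workspace host and API token.
llm = Databricks(endpoint_name="my-llm-endpoint")
print(llm("How are you?"))
```
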

docs/extras/ecosystem/integrations/mlflow_ai_gateway.md

Lines changed: 116 additions & 0 deletions

# MLflow AI Gateway

The MLflow AI Gateway service is a powerful tool designed to streamline the usage and management of various large language model (LLM) providers, such as OpenAI and Anthropic, within an organization. It offers a high-level interface that simplifies the interaction with these services by providing a unified endpoint to handle specific LLM related requests. See [the MLflow AI Gateway documentation](https://mlflow.org/docs/latest/gateway/index.html) for more details.

## Installation and Setup

Install `mlflow` with MLflow AI Gateway dependencies:

```sh
pip install 'mlflow[gateway]'
```

Set the OpenAI API key as an environment variable:

```sh
export OPENAI_API_KEY=...
```

Create a configuration file:

```yaml
routes:
  - name: completions
    type: llm/v1/completions
    model:
      provider: openai
      name: text-davinci-003
      config:
        openai_api_key: $OPENAI_API_KEY

  - name: embeddings
    type: llm/v1/embeddings
    model:
      provider: openai
      name: text-embedding-ada-002
      config:
        openai_api_key: $OPENAI_API_KEY
```

Start the Gateway server:

```sh
mlflow gateway start --config-path /path/to/config.yaml
```

## Completions Example

```python
import mlflow
from langchain import LLMChain, PromptTemplate
from langchain.llms import MlflowAIGateway

gateway = MlflowAIGateway(
    gateway_uri="http://127.0.0.1:5000",
    route="completions",
    params={
        "temperature": 0.0,
        "top_p": 0.1,
    },
)

llm_chain = LLMChain(
    llm=gateway,
    prompt=PromptTemplate(
        input_variables=["adjective"],
        template="Tell me a {adjective} joke",
    ),
)
result = llm_chain.run(adjective="funny")
print(result)

with mlflow.start_run():
    model_info = mlflow.langchain.log_model(llm_chain, "model")

model = mlflow.pyfunc.load_model(model_info.model_uri)
print(model.predict([{"adjective": "funny"}]))
```

## Embeddings Example

```python
from langchain.embeddings import MlflowAIGatewayEmbeddings

embeddings = MlflowAIGatewayEmbeddings(
    gateway_uri="http://127.0.0.1:5000",
    route="embeddings",
)

print(embeddings.embed_query("hello"))
print(embeddings.embed_documents(["hello"]))
```

## Databricks MLflow AI Gateway

Databricks MLflow AI Gateway is in private preview.
Please contact a Databricks representative to enroll in the preview.

```python
from langchain import LLMChain, PromptTemplate
from langchain.llms import MlflowAIGateway

gateway = MlflowAIGateway(
    gateway_uri="databricks",
    route="completions",
)

llm_chain = LLMChain(
    llm=gateway,
    prompt=PromptTemplate(
        input_variables=["adjective"],
        template="Tell me a {adjective} joke",
    ),
)
result = llm_chain.run(adjective="funny")
print(result)
```

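A minimal sanity-check sketch (not from the diff): the routes defined in the config above can be queried directly with the same `mlflow.gateway` client that the new LangChain classes call internally, assuming the gateway from the setup section is running at http://127.0.0.1:5000.

```python
import mlflow.gateway

# Point the client at the locally running gateway, as in the examples above.
mlflow.gateway.set_gateway_uri("http://127.0.0.1:5000")

# Query the "completions" route; MlflowAIGateway._call reads
# resp["candidates"][0]["text"] from this response shape.
completion = mlflow.gateway.query("completions", data={"prompt": "Say hello"})
print(completion["candidates"][0]["text"])

# Query the "embeddings" route; MlflowAIGatewayEmbeddings reads resp["embeddings"].
embedding = mlflow.gateway.query("embeddings", data={"text": ["Say hello"]})
print(embedding["embeddings"])
```
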
langchain/embeddings/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -24,6 +24,7 @@
 from langchain.embeddings.jina import JinaEmbeddings
 from langchain.embeddings.llamacpp import LlamaCppEmbeddings
 from langchain.embeddings.minimax import MiniMaxEmbeddings
+from langchain.embeddings.mlflow_gateway import MlflowAIGatewayEmbeddings
 from langchain.embeddings.modelscope_hub import ModelScopeEmbeddings
 from langchain.embeddings.mosaicml import MosaicMLInstructorEmbeddings
 from langchain.embeddings.octoai_embeddings import OctoAIEmbeddings
@@ -50,6 +51,7 @@
     "JinaEmbeddings",
     "LlamaCppEmbeddings",
     "HuggingFaceHubEmbeddings",
+    "MlflowAIGatewayEmbeddings",
     "ModelScopeEmbeddings",
     "TensorflowHubEmbeddings",
     "SagemakerEndpointEmbeddings",
```

langchain/embeddings/mlflow_gateway.py

Lines changed: 51 additions & 0 deletions

```python
from __future__ import annotations

from typing import Any, Iterator, List, Optional

from pydantic import BaseModel

from langchain.embeddings.base import Embeddings


def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
    for i in range(0, len(texts), size):
        yield texts[i : i + size]


class MlflowAIGatewayEmbeddings(Embeddings, BaseModel):
    route: str
    gateway_uri: Optional[str] = None

    def __init__(self, **kwargs: Any):
        try:
            import mlflow.gateway
        except ImportError as e:
            raise ImportError(
                "Could not import `mlflow.gateway` module. "
                "Please install it with `pip install mlflow[gateway]`."
            ) from e

        super().__init__(**kwargs)
        if self.gateway_uri:
            mlflow.gateway.set_gateway_uri(self.gateway_uri)

    def _query(self, texts: List[str]) -> List[List[float]]:
        try:
            import mlflow.gateway
        except ImportError as e:
            raise ImportError(
                "Could not import `mlflow.gateway` module. "
                "Please install it with `pip install mlflow[gateway]`."
            ) from e

        embeddings = []
        for txt in _chunk(texts, 20):
            resp = mlflow.gateway.query(self.route, data={"text": txt})
            embeddings.append(resp["embeddings"])
        return embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return self._query(texts)

    def embed_query(self, text: str) -> List[float]:
        return self._query([text])[0]
```

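A short sketch of how the new module behaves (plain Python, no gateway or API key needed): thanks to the `__init__.py` change above, the class is importable from the top-level package, and `_chunk` splits the input to `_query` into slices of at most 20 texts, each of which becomes one `mlflow.gateway.query` call. The document counts are illustrative only.

```python
from langchain.embeddings import MlflowAIGatewayEmbeddings  # re-exported above
from langchain.embeddings.mlflow_gateway import _chunk

# Both import paths point at the same class after the __init__.py change.
assert MlflowAIGatewayEmbeddings.__module__ == "langchain.embeddings.mlflow_gateway"

# _chunk yields slices of at most `size` items; with the hard-coded size of 20
# in _query, 45 documents would translate into three gateway requests.
texts = [f"doc-{i}" for i in range(45)]  # 45 hypothetical documents
print([len(batch) for batch in _chunk(texts, 20)])  # -> [20, 20, 5]
```
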
langchain/llms/__init__.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -33,6 +33,7 @@
 from langchain.llms.koboldai import KoboldApiLLM
 from langchain.llms.llamacpp import LlamaCpp
 from langchain.llms.manifest import ManifestWrapper
+from langchain.llms.mlflow_ai_gateway import MlflowAIGateway
 from langchain.llms.modal import Modal
 from langchain.llms.mosaicml import MosaicML
 from langchain.llms.nlpcloud import NLPCloud
@@ -89,6 +90,7 @@
     "LlamaCpp",
     "TextGen",
     "ManifestWrapper",
+    "MlflowAIGateway",
     "Modal",
     "MosaicML",
     "NLPCloud",
@@ -146,6 +148,7 @@
     "koboldai": KoboldApiLLM,
     "llamacpp": LlamaCpp,
     "textgen": TextGen,
+    "mlflow-gateway": MlflowAIGateway,
     "modal": Modal,
     "mosaic": MosaicML,
     "nlpcloud": NLPCloud,
```

langchain/llms/mlflow_ai_gateway.py

Lines changed: 75 additions & 0 deletions

```python
from __future__ import annotations

from typing import Any, Dict, List, Mapping, Optional

from pydantic import BaseModel, Extra

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM


class Params(BaseModel, extra=Extra.allow):
    temperature: float = 0.0
    candidate_count: int = 1
    stop: Optional[List[str]] = None
    max_tokens: Optional[int] = None


class MlflowAIGateway(LLM):
    route: str
    gateway_uri: Optional[str] = None
    params: Optional[Params] = None

    def __init__(self, **kwargs: Any):
        try:
            import mlflow.gateway
        except ImportError as e:
            raise ImportError(
                "Could not import `mlflow.gateway` module. "
                "Please install it with `pip install mlflow[gateway]`."
            ) from e

        super().__init__(**kwargs)
        if self.gateway_uri:
            mlflow.gateway.set_gateway_uri(self.gateway_uri)

    @property
    def _default_params(self) -> Dict[str, Any]:
        params: Dict[str, Any] = {
            "gateway_uri": self.gateway_uri,
            "route": self.route,
            **(self.params.dict() if self.params else {}),
        }
        return params

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return self._default_params

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        try:
            import mlflow.gateway
        except ImportError as e:
            raise ImportError(
                "Could not import `mlflow.gateway` module. "
                "Please install it with `pip install mlflow[gateway]`."
            ) from e

        data: Dict[str, Any] = {
            "prompt": prompt,
            **(self.params.dict() if self.params else {}),
        }
        if s := (stop or (self.params.stop if self.params else None)):
            data["stop"] = s
        resp = mlflow.gateway.query(self.route, data=data)
        return resp["candidates"][0]["text"]

    @property
    def _llm_type(self) -> str:
        return "mlflow-ai-gateway"
```

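A short sketch tying the pieces together (not from the diff): the `"mlflow-gateway"` key added to `type_to_cls_dict` above should let the LLM be rebuilt from a serialized config via LangChain's generic `load_llm_from_config`, and constructor kwargs are coerced into a `Params` model by pydantic, flattened into the request body alongside the prompt, with a per-call `stop` list taking precedence over `params.stop`. This assumes `mlflow[gateway]` is installed; only the commented-out call at the end needs a running gateway.

```python
from langchain.llms import MlflowAIGateway
from langchain.llms.loading import load_llm_from_config

# Rebuild the LLM from a config dict; "_type" is resolved against type_to_cls_dict.
config = {
    "_type": "mlflow-gateway",
    "gateway_uri": "http://127.0.0.1:5000",
    "route": "completions",
    "params": {"temperature": 0.2, "stop": ["\n\n"], "max_tokens": 64},
}
llm = load_llm_from_config(config)
assert isinstance(llm, MlflowAIGateway)

# gateway_uri/route plus the flattened Params appear in the identifying params, e.g.
# {'gateway_uri': 'http://127.0.0.1:5000', 'route': 'completions',
#  'temperature': 0.2, 'candidate_count': 1, 'stop': ['\n\n'], 'max_tokens': 64}
print(llm._identifying_params)

# With a gateway running, a per-call stop list would override params.stop:
# print(llm("Tell me a joke", stop=["."]))
```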