Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion airflow/providers/openai/hooks/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ def _get_api_base(self) -> None | str:
return conn.host

def create_embeddings(
self, text: str | list[Any], model: str = "text-embedding-ada-002", **kwargs: Any
self,
text: str | list[str] | list[int] | list[list[int]],
model: str = "text-embedding-ada-002",
**kwargs: Any,
) -> list[float]:
"""Generate embeddings for the given text using the given model.

Expand Down
27 changes: 16 additions & 11 deletions airflow/providers/openai/operators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,41 +31,46 @@ class OpenAIEmbeddingOperator(BaseOperator):
"""
Operator that accepts input text to generate OpenAI embeddings using the specified model.

:param conn_id: The OpenAI connection ID to use.
:param input_text: The text to generate OpenAI embeddings for. This can be a string, a list of strings,
a list of integers, or a list of lists of integers.
:param model: The OpenAI model to be used for generating the embeddings.
:param embedding_kwargs: Additional keyword arguments to pass to the OpenAI `create_embeddings` method.

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:OpenAIEmbeddingOperator`

:param conn_id: The OpenAI connection.
:param input_text: The text to generate OpenAI embeddings on. Either input_text or input_callable
should be provided.
:param model: The OpenAI model to be used for generating the embeddings.
:param embedding_kwargs: For possible option check
.. seealso:: https://platform.openai.com/docs/api-reference/embeddings/create
For possible options for `embedding_kwargs`, see:
https://platform.openai.com/docs/api-reference/embeddings/create
"""

template_fields: Sequence[str] = ("input_text",)

def __init__(
self,
conn_id: str,
input_text: str | list[Any],
input_text: str | list[str] | list[int] | list[list[int]],
model: str = "text-embedding-ada-002",
embedding_kwargs: dict | None = None,
**kwargs: Any,
):
self.embedding_kwargs = embedding_kwargs or {}
super().__init__(**kwargs)
self.conn_id = conn_id
self.input_text = input_text
self.model = model
self.embedding_kwargs = embedding_kwargs or {}

@cached_property
def hook(self) -> OpenAIHook:
    """Build the OpenAIHook for this operator's configured connection ID.

    Declared as a ``cached_property``, so the hook is presumably created on first
    access and reused for later accesses on the same operator instance.
    """
    openai_hook = OpenAIHook(conn_id=self.conn_id)
    return openai_hook

def execute(self, context: Context) -> list[float]:
self.log.info("Input text: %s", self.input_text)
if not self.input_text or not isinstance(self.input_text, (str, list)):
raise ValueError(
"The 'input_text' must be a non-empty string, list of strings, list of integers, or list of lists of integers."
)
self.log.info("Generating embeddings for the input text of length: %d", len(self.input_text))
embeddings = self.hook.create_embeddings(self.input_text, model=self.model, **self.embedding_kwargs)
self.log.info("Embeddings: %s", embeddings)
self.log.info("Generated embeddings for %d items", len(embeddings))
return embeddings
12 changes: 12 additions & 0 deletions tests/providers/openai/operators/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

from unittest.mock import Mock

import pytest

from airflow.providers.openai.operators.openai import OpenAIEmbeddingOperator
from airflow.utils.context import Context

Expand All @@ -34,3 +36,13 @@ def test_execute_with_input_text():
embeddings = operator.execute(context)

assert embeddings == [1.0, 2.0, 3.0]


@pytest.mark.parametrize("invalid_input", ["", None, 123])
def test_execute_with_invalid_input(invalid_input):
    """Check that ``execute`` rejects empty or wrongly typed ``input_text``.

    The parametrized values cover an empty string, ``None`` and a value that is
    neither ``str`` nor ``list``; ``OpenAIEmbeddingOperator.execute`` is expected to
    raise ``ValueError`` for each of them.
    """
    # Build the operator outside the ``raises`` block so that only the call under
    # test can satisfy the expectation; a ValueError raised during construction
    # now fails the test instead of letting it pass by accident.
    operator = OpenAIEmbeddingOperator(
        task_id="TaskId", conn_id="test_conn_id", model="test_model", input_text=invalid_input
    )
    context = Context()
    # ``match`` pins the error to the input validation message rather than any ValueError.
    with pytest.raises(ValueError, match="must be a non-empty"):
        operator.execute(context)