Skip to content

Commit 4f0982d

Browse files
committed
OpenAI operator- Changed generation of embeddings, moved validation of input text to the constructor of the operator, changed value of invalid value to ValueError instead of AirflowException
1 parent c986a93 commit 4f0982d

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

airflow/providers/openai/operators/openai.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from functools import cached_property
2121
from typing import TYPE_CHECKING, Any, Sequence
2222

23-
from airflow.exceptions import AirflowException
2423
from airflow.models import BaseOperator
2524
from airflow.providers.openai.hooks.openai import OpenAIHook
2625

@@ -60,18 +59,18 @@ def __init__(
6059
self.input_text = input_text
6160
self.model = model
6261
self.embedding_kwargs = embedding_kwargs or {}
62+
if not self.input_text or not isinstance(self.input_text, (str, list)):
63+
raise ValueError(
64+
"The 'input_text' must be a non-empty string, list of strings, list of integers, or list of lists of integers."
65+
)
6366

6467
@cached_property
6568
def hook(self) -> OpenAIHook:
6669
"""Return an instance of the OpenAIHook."""
6770
return OpenAIHook(conn_id=self.conn_id)
6871

6972
def execute(self, context: Context) -> list[float]:
70-
if not self.input_text or not isinstance(self.input_text, (str, list)):
71-
raise AirflowException(
72-
"The 'input_text' must be a non-empty string, list of strings, list of integers, or list of lists of integers."
73-
)
7473
self.log.info("Generating embeddings for the input text of length: %d", len(self.input_text))
7574
embeddings = self.hook.create_embeddings(self.input_text, model=self.model, **self.embedding_kwargs)
76-
self.log.info("Generated embeddings: %s", embeddings)
75+
self.log.info("Generated embeddings for %d items", len(embeddings))
7776
return embeddings

tests/providers/openai/operators/test_openai.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
import pytest
2222

23-
from airflow.exceptions import AirflowException
2423
from airflow.providers.openai.operators.openai import OpenAIEmbeddingOperator
2524
from airflow.utils.context import Context
2625

@@ -40,7 +39,7 @@ def test_execute_with_input_text():
4039

4140

4241
def test_execute_with_invalid_input_empty_string():
43-
with pytest.raises(AirflowException):
42+
with pytest.raises(ValueError):
4443
operator = OpenAIEmbeddingOperator(
4544
task_id="TaskId", conn_id="test_conn_id", model="test_model", input_text=""
4645
)
@@ -49,7 +48,7 @@ def test_execute_with_invalid_input_empty_string():
4948

5049

5150
def test_execute_with_invalid_input_none():
52-
with pytest.raises(AirflowException):
51+
with pytest.raises(ValueError):
5352
operator = OpenAIEmbeddingOperator(
5453
task_id="TaskId", conn_id="test_conn_id", model="test_model", input_text=None
5554
)
@@ -58,7 +57,7 @@ def test_execute_with_invalid_input_none():
5857

5958

6059
def test_execute_with_invalid_input_wrong_type():
61-
with pytest.raises(AirflowException):
60+
with pytest.raises(ValueError):
6261
operator = OpenAIEmbeddingOperator(
6362
task_id="TaskId", conn_id="test_conn_id", model="test_model", input_text=123
6463
)

0 commit comments

Comments
 (0)