feat(providers/openai): support batch api in hook/operator/trigger (apache#41554)

* feat(providers/openai)
    * support batch api in hook/operator/trigger
    * add wait_for_completion to OpenAITriggerBatchOperator

---------

Co-authored-by: YungHsiu Chen <yunghsiu1994@gmail.com>
Co-authored-by: Wei Lee <weilee.rx@gmail.com>
3 people authored Aug 22, 2024
1 parent 170b9ce commit 00e73e6
Showing 13 changed files with 858 additions and 8 deletions.
28 changes: 28 additions & 0 deletions airflow/providers/openai/exceptions.py
@@ -0,0 +1,28 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from airflow.exceptions import AirflowException


class OpenAIBatchJobException(AirflowException):
"""Raise when OpenAI Batch Job fails to start AFTER processing the request."""


class OpenAIBatchTimeout(AirflowException):
"""Raise when OpenAI Batch Job times out."""
103 changes: 100 additions & 3 deletions airflow/providers/openai/hooks/openai.py
@@ -17,13 +17,16 @@

from __future__ import annotations

import time
from enum import Enum
from functools import cached_property
from typing import TYPE_CHECKING, Any, BinaryIO, Literal

from openai import OpenAI

if TYPE_CHECKING:
from openai.types import FileDeleted, FileObject
from openai.types.batch import Batch
from openai.types.beta import (
Assistant,
AssistantDeleted,
@@ -42,8 +45,29 @@
ChatCompletionToolMessageParam,
ChatCompletionUserMessageParam,
)

from airflow.hooks.base import BaseHook
from airflow.providers.openai.exceptions import OpenAIBatchJobException, OpenAIBatchTimeout


class BatchStatus(str, Enum):
"""Enum for the status of a batch."""

VALIDATING = "validating"
FAILED = "failed"
IN_PROGRESS = "in_progress"
FINALIZING = "finalizing"
COMPLETED = "completed"
EXPIRED = "expired"
CANCELLING = "cancelling"
CANCELLED = "cancelled"

def __str__(self) -> str:
return str(self.value)

@classmethod
def is_in_progress(cls, status: str) -> bool:
"""Check if the batch status is in progress."""
return status in (cls.VALIDATING, cls.IN_PROGRESS, cls.FINALIZING)


class OpenAIHook(BaseHook):
@@ -288,13 +312,13 @@ def create_embeddings(
embeddings: list[float] = response.data[0].embedding
return embeddings

def upload_file(self, file: str, purpose: Literal["fine-tune", "assistants"]) -> FileObject:
def upload_file(self, file: str, purpose: Literal["fine-tune", "assistants", "batch"]) -> FileObject:
"""
Upload a file that can be used across various endpoints. The size of all the files uploaded by one organization can be up to 100 GB.

:param file: The File object (not file name) to be uploaded.
:param purpose: The intended purpose of the uploaded file. Use "fine-tune" for
Fine-tuning and "assistants" for Assistants and Messages.
Fine-tuning, "assistants" for Assistants and Messages, and "batch" for Batch API.
"""
with open(file, "rb") as file_stream:
file_object = self.conn.files.create(file=file_stream, purpose=purpose)
@@ -393,3 +417,76 @@ def delete_vector_store_file(self, vector_store_id: str, file_id: str) -> Vector
"""
response = self.conn.beta.vector_stores.files.delete(vector_store_id=vector_store_id, file_id=file_id)
return response

def create_batch(
self,
file_id: str,
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
metadata: dict[str, str] | None = None,
completion_window: Literal["24h"] = "24h",
) -> Batch:
"""
Create a batch for a given model and files.

:param file_id: The ID of the file to be used for this batch.
:param endpoint: The endpoint to use for this batch. Allowed values include:
'/v1/chat/completions', '/v1/embeddings', '/v1/completions'.
:param metadata: A set of key-value pairs that can be attached to an object.
:param completion_window: The time window for the batch to complete. Default is 24 hours.
"""
batch = self.conn.batches.create(
input_file_id=file_id, endpoint=endpoint, metadata=metadata, completion_window=completion_window
)
return batch

def get_batch(self, batch_id: str) -> Batch:
"""
Get the status of a batch.

:param batch_id: The ID of the batch to get the status of.
"""
batch = self.conn.batches.retrieve(batch_id=batch_id)
return batch

def wait_for_batch(self, batch_id: str, wait_seconds: float = 3, timeout: float = 3600) -> None:
"""
Poll a batch until it finishes.

:param batch_id: The ID of the batch to wait for.
:param wait_seconds: Optional. Number of seconds between checks.
:param timeout: Optional. How many seconds to wait for the batch to be ready.
    Used only when not run as a deferrable operator.
"""
start = time.monotonic()
while True:
if start + timeout < time.monotonic():
self.cancel_batch(batch_id=batch_id)
raise OpenAIBatchTimeout(f"Timeout: OpenAI Batch {batch_id} is not ready after {timeout}s")
batch = self.get_batch(batch_id=batch_id)

if BatchStatus.is_in_progress(batch.status):
time.sleep(wait_seconds)
continue
if batch.status == BatchStatus.COMPLETED:
return
if batch.status == BatchStatus.FAILED:
raise OpenAIBatchJobException(f"Batch failed - \n{batch_id}")
elif batch.status in (BatchStatus.CANCELLED, BatchStatus.CANCELLING):
raise OpenAIBatchJobException(f"Batch failed - batch was cancelled:\n{batch_id}")
elif batch.status == BatchStatus.EXPIRED:
raise OpenAIBatchJobException(
f"Batch failed - batch couldn't be completed within the hour time window :\n{batch_id}"
)

raise OpenAIBatchJobException(
f"Batch failed - encountered unexpected status `{batch.status}` for batch_id `{batch_id}`"
)

def cancel_batch(self, batch_id: str) -> Batch:
"""
Cancel a batch.

:param batch_id: The ID of the batch to cancel.
"""
batch = self.conn.batches.cancel(batch_id=batch_id)
return batch
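
Taken together, upload_file, create_batch, wait_for_batch, get_batch, and cancel_batch cover the full Batch API lifecycle. A minimal end-to-end sketch, assuming an illustrative JSONL payload, model name, and prompts (the Batch API expects one request object per line):

# Sketch: the full batch lifecycle through OpenAIHook (file name, model, prompts are illustrative).
import json

from airflow.providers.openai.hooks.openai import OpenAIHook

hook = OpenAIHook(conn_id="openai_default")

# Build the JSONL input file: one request object per line.
requests = [
    {
        "custom_id": f"request-{i}",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {"model": "gpt-4o-mini", "messages": [{"role": "user", "content": prompt}]},
    }
    for i, prompt in enumerate(["What is Airflow?", "What is a DAG?"])
]
with open("batch_input.jsonl", "w") as f:
    f.write("\n".join(json.dumps(r) for r in requests))

file_object = hook.upload_file("batch_input.jsonl", purpose="batch")
batch = hook.create_batch(file_id=file_object.id, endpoint="/v1/chat/completions")
hook.wait_for_batch(batch.id, wait_seconds=3, timeout=3600)  # blocks until a terminal status
print(hook.get_batch(batch.id).output_file_id)  # id of the results file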
94 changes: 93 additions & 1 deletion airflow/providers/openai/operators/openai.py
@@ -17,11 +17,15 @@

from __future__ import annotations

import time
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence
from typing import TYPE_CHECKING, Any, Literal, Sequence

from airflow.configuration import conf
from airflow.models import BaseOperator
from airflow.providers.openai.exceptions import OpenAIBatchJobException
from airflow.providers.openai.hooks.openai import OpenAIHook
from airflow.providers.openai.triggers.openai import OpenAIBatchTrigger

if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -74,3 +78,91 @@ def execute(self, context: Context) -> list[float]:
embeddings = self.hook.create_embeddings(self.input_text, model=self.model, **self.embedding_kwargs)
self.log.info("Generated embeddings for %d items", len(embeddings))
return embeddings


class OpenAITriggerBatchOperator(BaseOperator):
"""
Operator that triggers an OpenAI Batch API endpoint and waits for the batch to complete.

:param file_id: Required. The ID of the batch file to trigger.
:param endpoint: Required. The OpenAI Batch API endpoint to trigger.
:param conn_id: Optional. The OpenAI connection ID to use. Defaults to 'openai_default'.
:param deferrable: Optional. Run the operator in deferrable mode.
:param wait_seconds: Optional. Number of seconds between checks. Only used when ``deferrable`` is False.
    Defaults to 3 seconds.
:param timeout: Optional. The amount of time, in seconds, to wait for the request to complete.
    Only used when ``deferrable`` is False. Defaults to 24 hours, which is the SLA for the OpenAI Batch API.
:param wait_for_completion: Optional. Whether to wait for the batch to complete. If set to False, the operator
    will return immediately after triggering the batch. Defaults to True.

.. seealso::
    For more information on how to use this operator, please take a look at the guide:
    :ref:`howto/operator:OpenAITriggerBatchOperator`
"""

template_fields: Sequence[str] = ("file_id",)

def __init__(
self,
file_id: str,
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
conn_id: str = OpenAIHook.default_conn_name,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
wait_seconds: float = 3,
timeout: float = 24 * 60 * 60,
wait_for_completion: bool = True,
**kwargs: Any,
):
super().__init__(**kwargs)
self.conn_id = conn_id
self.file_id = file_id
self.endpoint = endpoint
self.deferrable = deferrable
self.wait_seconds = wait_seconds
self.timeout = timeout
self.wait_for_completion = wait_for_completion
self.batch_id: str | None = None

@cached_property
def hook(self) -> OpenAIHook:
"""Return an instance of the OpenAIHook."""
return OpenAIHook(conn_id=self.conn_id)

def execute(self, context: Context) -> str:
batch = self.hook.create_batch(file_id=self.file_id, endpoint=self.endpoint)
self.batch_id = batch.id
if self.wait_for_completion:
if self.deferrable:
self.defer(
timeout=self.execution_timeout,
trigger=OpenAIBatchTrigger(
conn_id=self.conn_id,
batch_id=self.batch_id,
poll_interval=60,
end_time=time.time() + self.timeout,
),
method_name="execute_complete",
)
else:
self.log.info("Waiting for batch %s to complete", self.batch_id)
self.hook.wait_for_batch(self.batch_id, wait_seconds=self.wait_seconds, timeout=self.timeout)
return self.batch_id

def execute_complete(self, context: Context, event: Any = None) -> str:
"""
Invoke this callback when the trigger fires, and return immediately.

Relies on the trigger to raise an exception; otherwise execution is assumed
to have been successful.
"""
if event["status"] == "error":
raise OpenAIBatchJobException(event["message"])

self.log.info("%s completed successfully.", self.task_id)
return event["batch_id"]

def on_kill(self) -> None:
"""Cancel the batch if task is cancelled."""
if self.batch_id:
self.log.info("on_kill: cancel the OpenAI Batch %s", self.batch_id)
self.hook.cancel_batch(self.batch_id)
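
In a DAG this collapses to a single task; a minimal sketch, assuming an upstream task that uploaded the JSONL file and pushed its id to XCom (dag id and task ids are illustrative):

# Sketch: OpenAITriggerBatchOperator in a DAG (ids are illustrative).
from datetime import datetime

from airflow import DAG
from airflow.providers.openai.operators.openai import OpenAITriggerBatchOperator

with DAG("openai_batch_example", start_date=datetime(2024, 8, 1), schedule=None):
    trigger_batch = OpenAITriggerBatchOperator(
        task_id="trigger_batch",
        file_id="{{ ti.xcom_pull(task_ids='upload_batch_file') }}",  # templated field
        endpoint="/v1/chat/completions",
        deferrable=True,  # free the worker slot while the batch runs
    )
# The task returns the batch id, so downstream tasks can pull it from XCom to fetch results.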
5 changes: 5 additions & 0 deletions airflow/providers/openai/provider.yaml
@@ -57,6 +57,11 @@ operators:
python-modules:
- airflow.providers.openai.operators.openai

triggers:
- integration-name: OpenAI
python-modules:
- airflow.providers.openai.triggers.openai

connection-types:
- hook-class-name: airflow.providers.openai.hooks.openai.OpenAIHook
connection-type: openai
16 changes: 16 additions & 0 deletions airflow/providers/openai/triggers/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
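
The OpenAIBatchTrigger module itself is among the files not rendered on this page. As an illustration only — not the committed source — here is a hedged sketch of the deferrable polling pattern implied by the operator's defer() call, matching the constructor arguments it passes and the event keys that execute_complete consumes:

# Illustration only: a polling trigger sketch (NOT the committed OpenAIBatchTrigger source).
import asyncio
import time
from typing import Any, AsyncIterator

from airflow.providers.openai.hooks.openai import BatchStatus, OpenAIHook
from airflow.triggers.base import BaseTrigger, TriggerEvent


class BatchTriggerSketch(BaseTrigger):
    """Poll a batch until it leaves an in-progress state or end_time passes."""

    def __init__(self, conn_id: str, batch_id: str, poll_interval: float, end_time: float) -> None:
        super().__init__()
        self.conn_id = conn_id
        self.batch_id = batch_id
        self.poll_interval = poll_interval
        self.end_time = end_time

    def serialize(self) -> tuple[str, dict[str, Any]]:
        # Triggers must serialize so the triggerer process can reconstruct them.
        return (
            "example.triggers.BatchTriggerSketch",
            {
                "conn_id": self.conn_id,
                "batch_id": self.batch_id,
                "poll_interval": self.poll_interval,
                "end_time": self.end_time,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        # A production trigger would use an async client; the sync hook keeps the sketch short.
        hook = OpenAIHook(conn_id=self.conn_id)
        while time.time() < self.end_time:
            batch = hook.get_batch(self.batch_id)
            if not BatchStatus.is_in_progress(batch.status):
                status = "success" if batch.status == BatchStatus.COMPLETED else "error"
                # These keys match what execute_complete reads above.
                yield TriggerEvent(
                    {"status": status, "batch_id": self.batch_id, "message": f"Batch {batch.status}"}
                )
                return
            await asyncio.sleep(self.poll_interval)
        yield TriggerEvent(
            {"status": "error", "batch_id": self.batch_id, "message": "Batch polling timed out"}
        )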
