Create Operators for Google Cloud Vertex AI Context Caching (#43008)
* Fix merge conflicts

* Fix documentation.

* Update return variables.
CYarros10 authored Oct 15, 2024
1 parent c9c4ca5 commit f1f9201
Showing 8 changed files with 442 additions and 9 deletions.
20 changes: 20 additions & 0 deletions docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
@@ -645,6 +645,26 @@ The operator returns the evaluation summary metrics in :ref:`XCom <concepts:xcom
:start-after: [START how_to_cloud_vertex_ai_run_evaluation_operator]
:end-before: [END how_to_cloud_vertex_ai_run_evaluation_operator]

To create cached content, you can use
:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.CreateCachedContentOperator`.
The operator returns the cached content resource name in :ref:`XCom <concepts:xcom>` under the ``return_value`` key.

.. exampleinclude:: /../../providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_vertex_ai_create_cached_content_operator]
:end-before: [END how_to_cloud_vertex_ai_create_cached_content_operator]
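
A minimal inline sketch of such a task, for orientation (the project, region, and
other parameter values below are illustrative, not taken from the example file):

.. code-block:: python

    create_cached_content_task = CreateCachedContentOperator(
        task_id="create_cached_content",
        project_id="my-project",
        location="us-central1",
        model_name="gemini-1.5-pro-002",
        system_instruction="You are an expert researcher.",
        ttl_hours=1,
        display_name="example-cache",
    )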

To generate a response from cached content, you can use
:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.GenerateFromCachedContentOperator`.
The operator returns the model's response in :ref:`XCom <concepts:xcom>` under the ``return_value`` key.

.. exampleinclude:: /../../providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_vertex_ai_generate_from_cached_content_operator]
:end-before: [END how_to_cloud_vertex_ai_generate_from_cached_content_operator]
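
Because ``cached_content_name`` is a templated field, the resource name returned by
the upstream task can be pulled from XCom. A minimal sketch (the task IDs and prompt
are illustrative):

.. code-block:: python

    generate_from_cached_content_task = GenerateFromCachedContentOperator(
        task_id="generate_from_cached_content",
        project_id="my-project",
        location="us-central1",
        cached_content_name="{{ task_instance.xcom_pull(task_ids='create_cached_content') }}",
        contents=["What are these papers about?"],
    )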

Reference
^^^^^^^^^

2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
@@ -640,7 +640,7 @@
"google-api-python-client>=2.0.2",
"google-auth-httplib2>=0.0.1",
"google-auth>=2.29.0",
"google-cloud-aiplatform>=1.63.0",
"google-cloud-aiplatform>=1.70.0",
"google-cloud-automl>=2.12.0",
"google-cloud-batch>=0.13.0",
"google-cloud-bigquery-datatransfer>=3.13.0",
@@ -20,12 +20,15 @@
from __future__ import annotations

import time
from datetime import timedelta
from typing import TYPE_CHECKING, Sequence

import vertexai
from vertexai.generative_models import GenerativeModel, Part
from vertexai.language_models import TextEmbeddingModel, TextGenerationModel
from vertexai.preview.caching import CachedContent
from vertexai.preview.evaluation import EvalResult, EvalTask
from vertexai.preview.generative_models import GenerativeModel as preview_generative_model
from vertexai.preview.tuning import sft

from airflow.exceptions import AirflowProviderDeprecationWarning
@@ -95,6 +98,16 @@ def get_eval_task(
)
return eval_task

def get_cached_context_model(
self,
cached_content_name: str,
) -> preview_generative_model:
"""Return a Generative Model with Cached Context."""
cached_content = CachedContent(cached_content_name=cached_content_name)

cached_context_model = preview_generative_model.from_cached_content(cached_content)
return cached_context_model

@deprecated(
planned_removal_date="January 01, 2025",
use_instead="Part objects included in contents parameter of "
@@ -528,3 +541,69 @@ def run_evaluation(
)

return eval_result

def create_cached_content(
self,
model_name: str,
location: str,
ttl_hours: float = 1,
system_instruction: str | None = None,
contents: list | None = None,
display_name: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
) -> str:
"""
Create CachedContent to reduce the cost of requests that contain repeat content with high input token counts.

:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param model_name: Required. The name of the publisher model to use for cached content.
:param system_instruction: Developer-set system instruction.
:param contents: The content to cache.
:param ttl_hours: The TTL for this resource in hours. The expiration time is computed as now + TTL.
Defaults to one hour.
:param display_name: The user-generated, meaningful display name of the cached content.
"""
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())

response = CachedContent.create(
model_name=model_name,
system_instruction=system_instruction,
contents=contents,
ttl=timedelta(hours=ttl_hours),
display_name=display_name,
)

return response.name

def generate_from_cached_content(
self,
location: str,
cached_content_name: str,
contents: list,
generation_config: dict | None = None,
safety_settings: dict | None = None,
project_id: str = PROVIDE_PROJECT_ID,
) -> str:
"""
Generate a response from CachedContent.

:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param cached_content_name: Required. The name of the cached content resource.
:param contents: Required. The multi-part content of a message that a user or a program
gives to the generative model, in order to elicit a specific response.
:param generation_config: Optional. Generation configuration settings.
:param safety_settings: Optional. Per-request settings for blocking unsafe content.
"""
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())

cached_context_model = self.get_cached_context_model(cached_content_name=cached_content_name)

response = cached_context_model.generate_content(
contents=contents,
generation_config=generation_config,
safety_settings=safety_settings,
)

return response.text
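
For orientation, a minimal sketch of calling these new hook methods directly (the
connection ID, project, and prompts below are illustrative assumptions):

from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook

hook = GenerativeModelHook(gcp_conn_id="google_cloud_default")

# Cache the repeated context once; the returned resource name identifies it.
cache_name = hook.create_cached_content(
    project_id="my-project",  # illustrative project ID
    location="us-central1",
    model_name="gemini-1.5-pro-002",
    system_instruction="You are an expert researcher.",
    ttl_hours=2,
)

# Reuse the cached context for follow-up prompts without resending it.
answer = hook.generate_from_cached_content(
    project_id="my-project",
    location="us-central1",
    cached_content_name=cache_name,
    contents=["Summarize the cached documents."],
)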
@@ -21,8 +21,6 @@

from typing import TYPE_CHECKING, Sequence

from google.cloud.aiplatform_v1beta1 import types as types_v1beta1

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
@@ -742,8 +740,6 @@ def execute(self, context: Context):
self.xcom_push(context, key="total_tokens", value=response.total_tokens)
self.xcom_push(context, key="total_billable_characters", value=response.total_billable_characters)

return types_v1beta1.CountTokensResponse.to_dict(response)


class RunEvaluationOperator(GoogleCloudBaseOperator):
"""
@@ -842,3 +838,155 @@ def execute(self, context: Context):
)

return response.summary_metrics


class CreateCachedContentOperator(GoogleCloudBaseOperator):
"""
Create CachedContent to reduce the cost of requests that contain repeat content with high input token counts.

:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param model_name: Required. The name of the publisher model to use for cached content.
:param system_instruction: Developer-set system instruction.
:param contents: The content to cache.
:param ttl_hours: The TTL for this resource in hours. The expiration time is computed as now + TTL.
Defaults to one hour.
:param display_name: The user-generated, meaningful display name of the cached content.
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
"""

template_fields = (
"location",
"project_id",
"impersonation_chain",
"model_name",
"contents",
"system_instruction",
)

def __init__(
self,
*,
project_id: str,
location: str,
model_name: str,
system_instruction: str | None = None,
contents: list | None = None,
ttl_hours: float = 1,
display_name: str | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)

self.project_id = project_id
self.location = location
self.model_name = model_name
self.system_instruction = system_instruction
self.contents = contents
self.ttl_hours = ttl_hours
self.display_name = display_name
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context: Context):
self.hook = GenerativeModelHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)

cached_content_name = self.hook.create_cached_content(
project_id=self.project_id,
location=self.location,
model_name=self.model_name,
system_instruction=self.system_instruction,
contents=self.contents,
ttl_hours=self.ttl_hours,
display_name=self.display_name,
)

self.log.info("Cached Content Name: %s", cached_content_name)

return cached_content_name


class GenerateFromCachedContentOperator(GoogleCloudBaseOperator):
"""
Generate a response from CachedContent.

:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param cached_content_name: Required. The name of the cached content resource.
:param contents: Required. The multi-part content of a message that a user or a program
gives to the generative model, in order to elicit a specific response.
:param generation_config: Optional. Generation configuration settings.
:param safety_settings: Optional. Per-request settings for blocking unsafe content.
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
"""

template_fields = (
"location",
"project_id",
"impersonation_chain",
"cached_content_name",
"contents",
)

def __init__(
self,
*,
project_id: str,
location: str,
cached_content_name: str,
contents: list,
generation_config: dict | None = None,
safety_settings: dict | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)

self.project_id = project_id
self.location = location
self.cached_content_name = cached_content_name
self.contents = contents
self.generation_config = generation_config
self.safety_settings = safety_settings
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context: Context):
self.hook = GenerativeModelHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
cached_content_text = self.hook.generate_from_cached_content(
project_id=self.project_id,
location=self.location,
cached_content_name=self.cached_content_name,
contents=self.contents,
generation_config=self.generation_config,
safety_settings=self.safety_settings,
)

self.log.info("Cached Content Response: %s", cached_content_text)

return cached_content_text
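
Taken together, a minimal DAG wiring for the two new operators might look like
this (the DAG ID, schedule, and parameter values are illustrative assumptions):

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
    CreateCachedContentOperator,
    GenerateFromCachedContentOperator,
)

with DAG(dag_id="example_context_caching", start_date=datetime(2024, 10, 1), schedule=None):
    create_cache = CreateCachedContentOperator(
        task_id="create_cached_content",
        project_id="my-project",
        location="us-central1",
        model_name="gemini-1.5-pro-002",
    )
    generate = GenerateFromCachedContentOperator(
        task_id="generate_from_cached_content",
        project_id="my-project",
        location="us-central1",
        cached_content_name="{{ task_instance.xcom_pull(task_ids='create_cached_content') }}",
        contents=["What are these documents about?"],
    )
    # The generate task consumes the cache name pushed to XCom by create_cache.
    create_cache >> generate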
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/google/provider.yaml
@@ -114,7 +114,7 @@ dependencies:
- google-api-python-client>=2.0.2
- google-auth>=2.29.0
- google-auth-httplib2>=0.0.1
- google-cloud-aiplatform>=1.63.0
- google-cloud-aiplatform>=1.70.0
- google-cloud-automl>=2.12.0
# Excluded versions contain bug https://github.com/apache/airflow/issues/39541 which is resolved in 3.24.0
- google-cloud-bigquery>=3.4.0,!=3.21.*,!=3.22.0,!=3.23.*
@@ -27,7 +27,9 @@

# For no Pydantic environment, we need to skip the tests
pytest.importorskip("google.cloud.aiplatform_v1")
from vertexai.generative_models import HarmBlockThreshold, HarmCategory, Tool, grounding
from datetime import timedelta

from vertexai.generative_models import HarmBlockThreshold, HarmCategory, Part, Tool, grounding
from vertexai.preview.evaluation import MetricPromptTemplateExamples

from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import (
@@ -106,6 +108,27 @@
TEST_EXPERIMENT_RUN_NAME = "eval-experiment-airflow-operator-run"
TEST_PROMPT_TEMPLATE = "{instruction}. Article: {context}. Summary:"

TEST_CACHED_CONTENT_NAME = "test-example-cache"
TEST_CACHED_CONTENT_PROMPT = ["What are these papers about?"]
TEST_CACHED_MODEL = "gemini-1.5-pro-002"
TEST_CACHED_SYSTEM_INSTRUCTION = """
You are an expert researcher. You always stick to the facts in the sources provided, and never make up new facts.
Now look at these research papers, and answer the following questions.
"""

TEST_CACHED_CONTENTS = [
Part.from_uri(
"gs://cloud-samples-data/generative-ai/pdf/2312.11805v3.pdf",
mime_type="application/pdf",
),
Part.from_uri(
"gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf",
mime_type="application/pdf",
),
]
TEST_CACHED_TTL = 1
TEST_CACHED_DISPLAY_NAME = "test-example-cache"

BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
GENERATIVE_MODEL_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.generative_model.{}"

@@ -299,3 +322,38 @@ def test_run_evaluation(self, mock_eval_task, mock_model) -> None:
prompt_template=TEST_PROMPT_TEMPLATE,
experiment_run_name=TEST_EXPERIMENT_RUN_NAME,
)

@mock.patch("vertexai.preview.caching.CachedContent.create")
def test_create_cached_content(self, mock_cached_content_create) -> None:
self.hook.create_cached_content(
project_id=GCP_PROJECT,
location=GCP_LOCATION,
model_name=TEST_CACHED_MODEL,
system_instruction=TEST_CACHED_SYSTEM_INSTRUCTION,
contents=TEST_CACHED_CONTENTS,
ttl_hours=TEST_CACHED_TTL,
display_name=TEST_CACHED_DISPLAY_NAME,
)

mock_cached_content_create.assert_called_once_with(
model_name=TEST_CACHED_MODEL,
system_instruction=TEST_CACHED_SYSTEM_INSTRUCTION,
contents=TEST_CACHED_CONTENTS,
ttl=timedelta(hours=TEST_CACHED_TTL),
display_name=TEST_CACHED_DISPLAY_NAME,
)

@mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_cached_context_model"))
def test_generate_from_cached_content(self, mock_cached_context_model) -> None:
self.hook.generate_from_cached_content(
project_id=GCP_PROJECT,
location=GCP_LOCATION,
cached_content_name=TEST_CACHED_CONTENT_NAME,
contents=TEST_CACHED_CONTENT_PROMPT,
)

mock_cached_context_model.return_value.generate_content.assert_called_once_with(
contents=TEST_CACHED_CONTENT_PROMPT,
generation_config=None,
safety_settings=None,
)