32 changes: 31 additions & 1 deletion providers/google/docs/changelog.rst
@@ -27,6 +27,36 @@
Changelog
---------

16.0.0
......

Breaking changes
~~~~~~~~~~~~~~~~

.. warning::
   Deprecated classes, parameters and features have been removed from the Google provider package.
   The following breaking changes were introduced:

   * Operators

     * ``Remove operator TextGenerationModelPredictOperator; use GenerativeModelGenerateContentOperator instead``

   * Hooks

     * ``Remove GenerativeModelHook.get_text_generation_model; use GenerativeModelHook.get_generative_model instead``
     * ``Remove GenerativeModelHook.text_generation_model_predict; use GenerativeModelHook.generative_model_generate_content instead``
     * ``Remove split_tablename function from airflow.providers.google.cloud.hooks.bigquery; use airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.split_tablename instead``

   * Links (migration sketch after this list)

     * ``Remove AutoMLDatasetLink; use TranslationLegacyDatasetLink instead``
     * ``Remove AutoMLDatasetListLink; use TranslationDatasetListLink instead``
     * ``Remove AutoMLModelLink; use TranslationLegacyModelLink instead``
     * ``Remove AutoMLModelTrainLink; use TranslationLegacyModelTrainLink instead``
     * ``Remove AutoMLModelPredictLink; use TranslationLegacyModelPredictLink instead``
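
For the link classes above, migration is a one-line import swap. A minimal sketch, assuming the replacement classes live in ``airflow.providers.google.cloud.links.translate`` as in recent provider releases (the old import path shown is hypothetical):

.. code-block:: python

    # Before (removed in 16.0.0) -- hypothetical import of a removed link class:
    # from airflow.providers.google.cloud.links.automl import AutoMLDatasetLink

    # After -- the Translation legacy equivalent:
    from airflow.providers.google.cloud.links.translate import TranslationLegacyDatasetLink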


15.1.0
......

@@ -60,7 +90,7 @@ Misc
Misc
~~~~

- * ``Deprecate Life Scrience service (#48862)``
+ * ``Deprecate Life Science service (#48862)``

Doc-only
~~~~~~~~
providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py
@@ -1992,74 +1992,6 @@ def _escape(s: str) -> str:
return e


@deprecated(
planned_removal_date="April 01, 2025",
use_instead="airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.split_tablename",
category=AirflowProviderDeprecationWarning,
)
def split_tablename(
table_input: str, default_project_id: str, var_name: str | None = None
) -> tuple[str, str, str]:
if "." not in table_input:
raise ValueError(f"Expected table name in the format of <dataset>.<table>. Got: {table_input}")

if not default_project_id:
raise ValueError("INTERNAL: No default project is specified")

def var_print(var_name):
if var_name is None:
return ""
return f"Format exception for {var_name}: "

if table_input.count(".") + table_input.count(":") > 3:
raise ValueError(f"{var_print(var_name)}Use either : or . to specify project got {table_input}")
cmpt = table_input.rsplit(":", 1)
project_id = None
rest = table_input
if len(cmpt) == 1:
project_id = None
rest = cmpt[0]
elif len(cmpt) == 2 and cmpt[0].count(":") <= 1:
if cmpt[-1].count(".") != 2:
project_id = cmpt[0]
rest = cmpt[1]
else:
raise ValueError(
f"{var_print(var_name)}Expect format of (<project:)<dataset>.<table>, got {table_input}"
)

cmpt = rest.split(".")
if len(cmpt) == 3:
if project_id:
raise ValueError(f"{var_print(var_name)}Use either : or . to specify project")
project_id = cmpt[0]
dataset_id = cmpt[1]
table_id = cmpt[2]

elif len(cmpt) == 2:
dataset_id = cmpt[0]
table_id = cmpt[1]
else:
raise ValueError(
f"{var_print(var_name)}Expect format of (<project.|<project:)<dataset>.<table>, got {table_input}"
)

# Exclude partition from the table name
table_id = table_id.split("$")[0]

if project_id is None:
if var_name is not None:
log.info(
'Project is not included in %s: %s; using project "%s"',
var_name,
table_input,
default_project_id,
)
project_id = default_project_id

return project_id, dataset_id, table_id


def _cleanse_time_partitioning(
destination_dataset_table: str | None, time_partitioning_in: dict | None
) -> dict: # if it is a partitioned table ($ is in the table name) add partition load option
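
For code that imported the module-level ``split_tablename``, the replacement is the method of the same name on ``BigQueryHook``, which keeps the ``(table_input, default_project_id, var_name)`` signature exercised by the remaining tests further down. A minimal sketch, assuming a configured ``google_cloud_default`` connection and placeholder project/dataset/table names:

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

# The module-level split_tablename() is gone; the same parsing now lives on the hook.
hook = BigQueryHook(gcp_conn_id="google_cloud_default")
project_id, dataset_id, table_id = hook.split_tablename(
    table_input="my-project:my_dataset.my_table",
    default_project_id="my-project",
)
assert (project_id, dataset_id, table_id) == ("my-project", "my_dataset", "my_table")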
providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
@@ -25,14 +25,12 @@

import vertexai
from vertexai.generative_models import GenerativeModel
- from vertexai.language_models import TextEmbeddingModel, TextGenerationModel
+ from vertexai.language_models import TextEmbeddingModel
from vertexai.preview.caching import CachedContent
from vertexai.preview.evaluation import EvalResult, EvalTask
from vertexai.preview.generative_models import GenerativeModel as preview_generative_model
from vertexai.preview.tuning import sft

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.common.deprecated import deprecated
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook

if TYPE_CHECKING:
@@ -43,16 +41,6 @@
class GenerativeModelHook(GoogleBaseHook):
"""Hook for Google Cloud Vertex AI Generative Model APIs."""

@deprecated(
planned_removal_date="April 09, 2025",
use_instead="GenerativeModelHook.get_generative_model",
category=AirflowProviderDeprecationWarning,
)
def get_text_generation_model(self, pretrained_model: str):
"""Return a Model Garden Model object based on Text Generation."""
model = TextGenerationModel.from_pretrained(pretrained_model)
return model

def get_text_embedding_model(self, pretrained_model: str):
"""Return a Model Garden Model object based on Text Embedding."""
model = TextEmbeddingModel.from_pretrained(pretrained_model)
@@ -100,59 +88,6 @@ def get_cached_context_model(
cached_context_model = preview_generative_model.from_cached_content(cached_content)
return cached_context_model

@deprecated(
planned_removal_date="April 09, 2025",
use_instead="GenerativeModelHook.generative_model_generate_content",
category=AirflowProviderDeprecationWarning,
)
@GoogleBaseHook.fallback_to_default_project_id
def text_generation_model_predict(
self,
prompt: str,
pretrained_model: str,
temperature: float,
max_output_tokens: int,
top_p: float,
top_k: int,
location: str,
project_id: str = PROVIDE_PROJECT_ID,
) -> str:
"""
Use the Vertex AI PaLM API to generate natural language text.

:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param prompt: Required. Inputs or queries that a user or a program gives
to the Vertex AI PaLM API, in order to elicit a specific response.
:param pretrained_model: A pre-trained model optimized for performing natural
language tasks such as classification, summarization, extraction, content
creation, and ideation.
:param temperature: Temperature controls the degree of randomness in token
selection.
:param max_output_tokens: Token limit determines the maximum amount of text
output.
:param top_p: Tokens are selected from most probable to least until the sum
of their probabilities equals the top_p value. Defaults to 0.8.
:param top_k: A top_k of 1 means the selected token is the most probable
among all tokens.
"""
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())

parameters = {
"temperature": temperature,
"max_output_tokens": max_output_tokens,
"top_p": top_p,
"top_k": top_k,
}

model = self.get_text_generation_model(pretrained_model)

response = model.predict(
prompt=prompt,
**parameters,
)
return response.text

@GoogleBaseHook.fallback_to_default_project_id
def text_embedding_model_get_embeddings(
self,
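
For direct hook callers, a minimal sketch of moving from the removed ``text_generation_model_predict`` to ``generative_model_generate_content``. The keyword names (``contents``, ``generation_config``, ``pretrained_model``) are assumptions based on recent provider releases, not part of this diff:

hook = GenerativeModelHook(gcp_conn_id="google_cloud_default")
response = hook.generative_model_generate_content(
    project_id="my-project",  # placeholder project
    location="us-central1",
    contents=["Summarize Apache Airflow in one sentence."],
    # The flat temperature/max_output_tokens/top_p/top_k arguments of the
    # removed method move into a single generation_config mapping.
    generation_config={"temperature": 0.0, "max_output_tokens": 256},
    pretrained_model="gemini-1.5-flash",  # replaces the retired text-bison default
)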
providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
@@ -22,105 +22,13 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
from airflow.providers.google.common.deprecated import deprecated

if TYPE_CHECKING:
from airflow.utils.context import Context


@deprecated(
planned_removal_date="April 09, 2025",
use_instead="GenerativeModelGenerateContentOperator",
category=AirflowProviderDeprecationWarning,
)
class TextGenerationModelPredictOperator(GoogleCloudBaseOperator):
"""
Uses the Vertex AI PaLM API to generate natural language text.

:param project_id: Required. The ID of the Google Cloud project that the
service belongs to (templated).
:param location: Required. The ID of the Google Cloud location that the
service belongs to (templated).
:param prompt: Required. Inputs or queries that a user or a program gives
to the Vertex AI PaLM API, in order to elicit a specific response (templated).
:param pretrained_model: By default uses the pre-trained model `text-bison`,
optimized for performing natural language tasks such as classification,
summarization, extraction, content creation, and ideation.
:param temperature: Temperature controls the degree of randomness in token
selection. Defaults to 0.0.
:param max_output_tokens: Token limit determines the maximum amount of text
output. Defaults to 256.
:param top_p: Tokens are selected from most probable to least until the sum
of their probabilities equals the top_p value. Defaults to 0.8.
:param top_k: A top_k of 1 means the selected token is the most probable
among all tokens. Defaults to 0.4.
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
"""

template_fields = ("location", "project_id", "impersonation_chain", "prompt")

def __init__(
self,
*,
project_id: str,
location: str,
prompt: str,
pretrained_model: str = "text-bison",
temperature: float = 0.0,
max_output_tokens: int = 256,
top_p: float = 0.8,
top_k: int = 40,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.project_id = project_id
self.location = location
self.prompt = prompt
self.pretrained_model = pretrained_model
self.temperature = temperature
self.max_output_tokens = max_output_tokens
self.top_p = top_p
self.top_k = top_k
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context: Context):
self.hook = GenerativeModelHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)

self.log.info("Submitting prompt")
response = self.hook.text_generation_model_predict(
project_id=self.project_id,
location=self.location,
prompt=self.prompt,
pretrained_model=self.pretrained_model,
temperature=self.temperature,
max_output_tokens=self.max_output_tokens,
top_p=self.top_p,
top_k=self.top_k,
)

self.log.info("Model response: %s", response)
self.xcom_push(context, key="model_response", value=response)

return response


class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
"""
Uses the Vertex AI Embeddings API to generate embeddings based on prompt.
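
In DAG code, the removed operator maps onto its replacement roughly as below. A minimal sketch, assuming ``GenerativeModelGenerateContentOperator`` accepts ``contents``, ``generation_config`` and ``pretrained_model`` keywords as in recent provider releases:

from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
    GenerativeModelGenerateContentOperator,
)

generate_content = GenerativeModelGenerateContentOperator(
    task_id="generate_content",  # hypothetical task id
    project_id="my-project",
    location="us-central1",
    contents=["Write a haiku about DAGs."],  # replaces the old prompt argument
    generation_config={"temperature": 0.0, "max_output_tokens": 256},
    pretrained_model="gemini-1.5-flash",
)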
57 changes: 0 additions & 57 deletions providers/google/tests/unit/google/cloud/hooks/test_bigquery.py
@@ -49,7 +49,6 @@
_format_schema_for_description,
_validate_src_fmt_configs,
_validate_value,
split_tablename,
)

pytestmark = pytest.mark.filterwarnings("error::airflow.exceptions.AirflowProviderDeprecationWarning")
@@ -772,62 +771,6 @@ def test_split_tablename_invalid_syntax(self, table_input, var_name, exception_m
self.hook.split_tablename(table_input, default_project_id, var_name)


class TestBigQueryTableSplitter:
def test_internal_need_default_project(self):
with pytest.raises(AirflowProviderDeprecationWarning):
split_tablename("dataset.table", None)

@pytest.mark.parametrize("partition", ["$partition", ""])
@pytest.mark.parametrize(
"project_expected, dataset_expected, table_expected, table_input",
[
("project", "dataset", "table", "dataset.table"),
("alternative", "dataset", "table", "alternative:dataset.table"),
("alternative", "dataset", "table", "alternative.dataset.table"),
("alt1:alt", "dataset", "table", "alt1:alt.dataset.table"),
("alt1:alt", "dataset", "table", "alt1:alt:dataset.table"),
],
)
def test_split_tablename(
self, project_expected, dataset_expected, table_expected, table_input, partition
):
default_project_id = "project"
with pytest.raises(AirflowProviderDeprecationWarning):
split_tablename(table_input + partition, default_project_id)

@pytest.mark.parametrize(
"table_input, var_name, exception_message",
[
("alt1:alt2:alt3:dataset.table", None, "Use either : or . to specify project got {}"),
(
"alt1.alt.dataset.table",
None,
r"Expect format of \(<project\.\|<project\:\)<dataset>\.<table>, got {}",
),
(
"alt1:alt2:alt.dataset.table",
"var_x",
"Format exception for var_x: Use either : or . to specify project got {}",
),
(
"alt1:alt2:alt:dataset.table",
"var_x",
"Format exception for var_x: Use either : or . to specify project got {}",
),
(
"alt1.alt.dataset.table",
"var_x",
r"Format exception for var_x: Expect format of "
r"\(<project\.\|<project:\)<dataset>.<table>, got {}",
),
],
)
def test_invalid_syntax(self, table_input, var_name, exception_message):
default_project_id = "project"
with pytest.raises(AirflowProviderDeprecationWarning):
split_tablename(table_input, default_project_id, var_name)


@pytest.mark.db_test
class TestTableOperations(_BigQueryBaseTestClass):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.Table")
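
Equivalent coverage for the removed splitter tests lives on the hook method. A minimal sketch of such a test, assuming ``_BigQueryBaseTestClass`` provides ``self.hook`` as the surviving tests suggest:

class TestHookSplitTablename(_BigQueryBaseTestClass):
    def test_split_tablename_with_project(self):
        # Same parsing behavior the removed module-level tests exercised.
        project, dataset, table = self.hook.split_tablename("alternative:dataset.table", "project")
        assert (project, dataset, table) == ("alternative", "dataset", "table")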