Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 27 additions & 21 deletions airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def prompt_language_model(
"""
Use the Vertex AI PaLM API to generate natural language text.

:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param prompt: Required. Inputs or queries that a user or a program gives
to the Vertex AI PaLM API, in order to elicit a specific response.
:param pretrained_model: A pre-trained model optimized for performing natural
Expand All @@ -141,8 +143,6 @@ def prompt_language_model(
of their probabilities equals the top_p value. Defaults to 0.8.
:param top_k: A top_k of 1 means the selected token is the most probable
among all tokens.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
"""
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())

Expand Down Expand Up @@ -178,11 +178,11 @@ def generate_text_embeddings(
"""
Use the Vertex AI PaLM API to generate text embeddings.

:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param prompt: Required. Inputs or queries that a user or a program gives
to the Vertex AI PaLM API, in order to elicit a specific response.
:param pretrained_model: A pre-trained model optimized for generating text embeddings.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
"""
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
model = self.get_text_embedding_model(pretrained_model)
Expand Down Expand Up @@ -210,16 +210,16 @@ def prompt_multimodal_model(
"""
Use the Vertex AI Gemini Pro foundation model to generate natural language text.

:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param prompt: Required. Inputs or queries that a user or a program gives
to the Multi-modal model, in order to elicit a specific response.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param generation_config: Optional. Generation configuration settings.
:param safety_settings: Optional. Per request settings for blocking unsafe content.
:param pretrained_model: By default uses the pre-trained model `gemini-pro`,
supporting prompts with text-only input, including natural language
tasks, multi-turn text and code chat, and code generation. It can
output text and code.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
"""
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())

Expand Down Expand Up @@ -251,6 +251,8 @@ def prompt_multimodal_model_with_media(
"""
Use the Vertex AI Gemini Pro foundation model to generate natural language text.

:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param prompt: Required. Inputs or queries that a user or a program gives
to the Multi-modal model, in order to elicit a specific response.
:param generation_config: Optional. Generation configuration settings.
Expand All @@ -262,8 +264,6 @@ def prompt_multimodal_model_with_media(
:param media_gcs_path: A GCS path to a content file such as an image or a video.
Can be passed to the multi-modal model as part of the prompt. Used with vision models.
:param mime_type: Validates the media type presented by the file in the media_gcs_path.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
"""
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())

Expand All @@ -290,6 +290,8 @@ def text_generation_model_predict(
"""
Use the Vertex AI PaLM API to generate natural language text.

:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param prompt: Required. Inputs or queries that a user or a program gives
to the Vertex AI PaLM API, in order to elicit a specific response.
:param pretrained_model: A pre-trained model optimized for performing natural
Expand All @@ -303,8 +305,6 @@ def text_generation_model_predict(
of their probabilities equals the top_p value. Defaults to 0.8.
:param top_k: A top_k of 1 means the selected token is the most probable
among all tokens.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
"""
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())

Expand Down Expand Up @@ -334,11 +334,11 @@ def text_embedding_model_get_embeddings(
"""
Use the Vertex AI PaLM API to generate text embeddings.

:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param prompt: Required. Inputs or queries that a user or a program gives
to the Vertex AI PaLM API, in order to elicit a specific response.
:param pretrained_model: A pre-trained model optimized for generating text embeddings.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
"""
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
model = self.get_text_embedding_model(pretrained_model)
Expand All @@ -355,26 +355,31 @@ def generative_model_generate_content(
tools: list | None = None,
generation_config: dict | None = None,
safety_settings: dict | None = None,
system_instruction: str | None = None,
pretrained_model: str = "gemini-pro",
project_id: str = PROVIDE_PROJECT_ID,
) -> str:
"""
Use the Vertex AI Gemini Pro foundation model to generate natural language text.

:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param contents: Required. The multi-part content of a message that a user or a program
gives to the generative model, in order to elicit a specific response.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param generation_config: Optional. Generation configuration settings.
:param safety_settings: Optional. Per request settings for blocking unsafe content.
:param tools: Optional. A list of tools available to the model during evaluation, such as a data store.
:param system_instruction: Optional. An instruction given to the model to guide its behavior.
:param pretrained_model: By default uses the pre-trained model `gemini-pro`,
supporting prompts with text-only input, including natural language
tasks, multi-turn text and code chat, and code generation. It can
output text and code.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
"""
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())

model = self.get_generative_model(pretrained_model)
model = self.get_generative_model(
pretrained_model=pretrained_model, system_instruction=system_instruction
)
response = model.generate_content(
contents=contents,
tools=tools,
Expand All @@ -400,12 +405,13 @@ def supervised_fine_tuning_train(
"""
Use the Supervised Fine Tuning API to create a tuning job.

:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param source_model: Required. A pre-trained model optimized for performing natural
language tasks such as classification, summarization, extraction, content
creation, and ideation.
:param train_dataset: Required. Cloud Storage URI of your training dataset. The dataset
must be formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param tuned_model_display_name: Optional. Display name of the TunedModel. The name can be up
to 128 characters long and can consist of any UTF-8 characters.
:param validation_dataset: Optional. Cloud Storage URI of your training dataset. The dataset must be
Expand Down Expand Up @@ -447,18 +453,18 @@ def count_tokens(
"""
Use the Vertex AI Count Tokens API to calculate the number of input tokens before sending a request to the Gemini API.

:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param contents: Required. The multi-part content of a message that a user or a program
gives to the generative model, in order to elicit a specific response.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param pretrained_model: By default uses the pre-trained model `gemini-pro`,
supporting prompts with text-only input, including natural language
tasks, multi-turn text and code chat, and code generation. It can
output text and code.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
"""
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())

model = self.get_generative_model(pretrained_model)
model = self.get_generative_model(pretrained_model=pretrained_model)
response = model.count_tokens(
contents=contents,
)
Expand All @@ -484,6 +490,8 @@ def run_evaluation(
"""
Use the Rapid Evaluation API to evaluate a model.

:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param pretrained_model: Required. A pre-trained model optimized for performing natural
language tasks such as classification, summarization, extraction, content
creation, and ideation.
Expand All @@ -492,8 +500,6 @@ def run_evaluation(
:param experiment_name: Required. The name of the evaluation experiment.
:param experiment_run_name: Required. The specific run name or ID for this experiment.
:param prompt_template: Required. The template used to format the model's prompts during evaluation. Adheres to Rapid Evaluation API.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param generation_config: Optional. A dictionary containing generation parameters for the model.
:param safety_settings: Optional. A dictionary specifying harm category thresholds for blocking model outputs.
:param system_instruction: Optional. An instruction given to the model to guide its behavior.
Expand Down
Loading