Commit
fix(generativeai): update prompt optimizer sample and tests (GoogleCloudPlatform#12630)

Update samples and tests for Prompt Optimizer. The static resources have been moved to GCS buckets and the tests can now be made simpler.
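Because the tests now read the prompt-optimizer resources from a shared bucket instead of provisioning one per run, the bucket-creation fixtures below can be dropped. A minimal sketch of the one-time upload this migration implies, assuming the bucket and prefix used by the new tests (python-docs-samples-tests, ai-platform/prompt_optimization/instructions/) and the old local test_resources directory as the source:

    # One-time migration sketch: upload the static test resources to GCS.
    # Bucket name, destination prefix, and source directory are assumptions
    # taken from the constants in the updated files, not part of this commit.
    import os

    from google.cloud import storage

    BUCKET_NAME = "python-docs-samples-tests"
    DEST_PREFIX = "ai-platform/prompt_optimization/instructions"
    SOURCE_DIR = "generative_ai/prompts/test_resources"

    client = storage.Client()
    bucket = client.bucket(BUCKET_NAME)
    for filename in os.listdir(SOURCE_DIR):
        blob = bucket.blob(f"{DEST_PREFIX}/{filename}")
        blob.upload_from_filename(os.path.join(SOURCE_DIR, filename))
        print(f"Uploaded gs://{BUCKET_NAME}/{blob.name}")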
Sita04 authored and riathakkar committed Oct 8, 2024
1 parent 8c38fa1 commit 50cf5fb
Showing 5 changed files with 75 additions and 158 deletions.
12 changes: 8 additions & 4 deletions generative_ai/batch_predict/gemini_batch_predict.py
@@ -18,7 +18,9 @@
 PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
 
 
-def batch_predict_gemini_createjob(input_uri: str, output_uri: str) -> BatchPredictionJob:
+def batch_predict_gemini_createjob(
+    input_uri: str, output_uri: str
+) -> BatchPredictionJob:
     """Perform batch text prediction using a Gemini AI model.
     Args:
         input_uri (str): URI of the input file in BigQuery table or Google Cloud Storage.
@@ -47,7 +49,7 @@ def batch_predict_gemini_createjob(input_uri: str, output_uri: str) -> BatchPredictionJob:
     batch_prediction_job = BatchPredictionJob.submit(
         source_model="gemini-1.5-flash-001",
         input_dataset=input_uri,
-        output_uri_prefix=output_uri
+        output_uri_prefix=output_uri,
     )
 
     # Check job status
@@ -82,5 +84,7 @@ def batch_predict_gemini_createjob(input_uri: str, output_uri: str) -> BatchPredictionJob:
 if __name__ == "__main__":
     # TODO(developer): Update GCS bucket and file paths
     GCS_BUCKET = "yourbucket"  # bucket name only; "gs://" is prefixed below
-    batch_predict_gemini_createjob(f"gs://{GCS_BUCKET}/batch_data/sample_input_file.jsonl",
-                                   f"gs://{GCS_BUCKET}/batch_preditions/sample_output/")
+    batch_predict_gemini_createjob(
+        f"gs://{GCS_BUCKET}/batch_data/sample_input_file.jsonl",
+        f"gs://{GCS_BUCKET}/batch_predictions/sample_output/",
+    )
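For context, each line of the input JSONL passed to batch_predict_gemini_createjob is one standalone request. A sketch of producing such a file, assuming the request/contents layout described in the Vertex AI batch prediction docs (an assumption; the actual fixture file is not shown in this commit):

    # Write a tiny JSONL batch input file. The {"request": {"contents": [...]}}
    # layout is an assumption based on Vertex AI batch prediction docs,
    # not taken from this diff.
    import json

    requests = [
        {"request": {"contents": [{"role": "user", "parts": [{"text": "What is GCS?"}]}]}},
        {"request": {"contents": [{"role": "user", "parts": [{"text": "Name two Gemini models."}]}]}},
    ]
    with open("sample_input_file.jsonl", "w") as f:
        for r in requests:
            f.write(json.dumps(r) + "\n")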
3 changes: 2 additions & 1 deletion generative_ai/batch_predict/test_batch_predict.py
@@ -79,6 +79,7 @@ def test_batch_gemini_predict(output_folder: pytest.fixture()) -> None:
     input_uri = f"gs://{INPUT_BUCKET}/batch/prompt_for_batch_gemini_predict.jsonl"
     job = _main_test(
         test_func=lambda: gemini_batch_predict.batch_predict_gemini_createjob(
-            input_uri, output_folder)
+            input_uri, output_folder
+        )
     )
     assert OUTPUT_PATH in job.output_location
70 changes: 35 additions & 35 deletions generative_ai/prompts/prompt_optimizer.py
@@ -14,60 +14,60 @@
 
 import os
 
+PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"]
 
-def optimize_prompts(
-    project: str,
-    location: str,
-    staging_bucket: str,
-    configuration_path: str,
+
+def prompts_custom_job_example(
+    cloud_bucket: str, config_path: str, output_path: str
 ) -> str:
     """Improve prompts by evaluating the model's response to sample prompts against specified evaluation metric(s).
     Args:
-        project: Google Cloud Project ID.
-        location: Location where you want to run the Vertex AI prompt optimizer.
-        staging_bucket: Specify the Google Cloud Storage bucket to store outputs and metadata. For example, gs://bucket-name
-        configuration_path: URI of the configuration file in your Google Cloud Storage bucket. For example, gs://bucket-name/configuration.json.
-    Returns:
-        custom_job.resource_name: Returns the resource name of the job created of type: projects/project-id/locations/location/customJobs/job-id
+        cloud_bucket(str): Specify the Google Cloud Storage bucket to store outputs and metadata. For example, gs://bucket-name
+        config_path(str): Filepath for config file in your Google Cloud Storage bucket. For example, prompts/custom_job/instructions/configuration.json
+        output_path(str): Filepath of the folder location in your Google Cloud Storage bucket. For example, prompts/custom_job/output
+    Returns(str):
+        Resource name of the job created. For example, projects/<project-id>/locations/location/customJobs/<job-id>
     """
     # [START generativeaionvertexai_prompt_optimizer]
     from google.cloud import aiplatform
 
-    # TODO(developer): Update & uncomment below line
-    # project = "your-gcp-project-id"
-    # location = "location"
-    # staging_bucket = "output-bucket-gcs-uri"
-    # configuration_path = "configuration-file-gcs-uri"
-    aiplatform.init(project=project, location=location, staging_bucket=staging_bucket)
+    # Initialize Vertex AI platform
+    aiplatform.init(project=PROJECT_ID, location="us-central1")
 
-    worker_pool_specs = [
-        {
-            "replica_count": 1,
-            "container_spec": {
-                "image_uri": "us-docker.pkg.dev/vertex-ai-restricted/builtin-algorithm/apd:preview_v1_0",
-                "args": [f"--config={configuration_path}"],
-            },
-            "machine_spec": {
-                "machine_type": "n1-standard-4",
-            },
-        }
-    ]
+    # TODO(Developer): Check and update lines below
+    # cloud_bucket = "gs://cloud-samples-data"
+    # config_path = f"{cloud_bucket}/instructions/sample_configuration.json"
+    # output_path = "custom_job/output/"
 
     custom_job = aiplatform.CustomJob(
         display_name="Prompt Optimizer example",
-        worker_pool_specs=worker_pool_specs,
+        worker_pool_specs=[
+            {
+                "replica_count": 1,
+                "container_spec": {
+                    "image_uri": "us-docker.pkg.dev/vertex-ai-restricted/builtin-algorithm/apd:preview_v1_0",
+                    "args": [f"--config={cloud_bucket}/{config_path}"],
+                },
+                "machine_spec": {
+                    "machine_type": "n1-standard-4",
+                },
+            }
+        ],
+        staging_bucket=cloud_bucket,
+        base_output_dir=f"{cloud_bucket}/{output_path}",
     )
 
     custom_job.submit()
+    print(f"Job resource name: {custom_job.resource_name}")
 
+    # Example response:
+    # 'projects/123412341234/locations/us-central1/customJobs/12341234123412341234'
     # [END generativeaionvertexai_prompt_optimizer]
     return custom_job.resource_name
 
 
 if __name__ == "__main__":
-    optimize_prompts(
-        os.environ["PROJECT_ID"],
-        "us-central1",
-        os.environ["PROMPT_OPTIMIZER_BUCKET_NAME"],
+    prompts_custom_job_example(
+        os.environ["CLOUD_BUCKET"],
+        os.environ["JSON_CONFIG_PATH"],
+        os.environ["OUTPUT_PATH"],
    )
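The __main__ block above pulls its arguments from environment variables; the sample can also be called directly. A minimal sketch, assuming the bucket layout the updated test relies on:

    # Hypothetical direct call; bucket and paths mirror the new test constants.
    from prompt_optimizer import prompts_custom_job_example

    job_name = prompts_custom_job_example(
        cloud_bucket="gs://python-docs-samples-tests",
        config_path="ai-platform/prompt_optimization/instructions/sample_configuration.json",
        output_path="ai-platform/prompt_optimization/output/",
    )
    print(job_name)
    # e.g. projects/<project-id>/locations/us-central1/customJobs/<job-id>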
140 changes: 26 additions & 114 deletions generative_ai/prompts/test_prompt_optimizer.py
@@ -12,138 +12,50 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 import os
 import time
-from typing import Callable
 
 from google.cloud import aiplatform, storage
-from google.cloud.aiplatform import CustomJob
 from google.cloud.aiplatform_v1 import JobState
-from google.cloud.exceptions import NotFound
-from google.cloud.storage import transfer_manager
-
-from prompt_optimizer import optimize_prompts
 
-import pytest
+from prompt_optimizer import prompts_custom_job_example
 
 PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
-STAGING_BUCKET_NAME = "prompt_optimizer_bucket"
-CONFIGURATION_DIRECTORY = "test_resources"
-CONFIGURATION_FILENAME = "sample_configuration.json"
-LOCATION = "us-central1"
-OUTPUT_PATH = "instruction"
-
-STORAGE_CLIENT = storage.Client()
-
-
-def _clean_resources(bucket_resource_name: str) -> None:
-    # delete blobs and bucket if exists
-    try:
-        bucket = STORAGE_CLIENT.get_bucket(bucket_resource_name)
-    except NotFound:
-        print(f"Bucket {bucket_resource_name} cannot be accessed")
-        return
-
-    blobs = bucket.list_blobs()
-    for blob in blobs:
-        blob.delete()
-    bucket.delete()
-
-
-def substitute_env_variable(data: dict, target_key: str, env_var_name: str) -> dict:
-    # substitute env variables in the given config file with runtime values
-    if isinstance(data, dict):
-        for key, value in data.items():
-            if key == target_key:
-                data[key] = os.environ.get(env_var_name)
-            else:
-                data[key] = substitute_env_variable(value, target_key, env_var_name)
-    elif isinstance(data, list):
-        for i, value in enumerate(data):
-            data[i] = substitute_env_variable(value, target_key, env_var_name)
-    return data
-
-
-def update_json() -> dict:
-    # Load the JSON file
-    file_path = os.path.join(
-        os.path.dirname(__file__), CONFIGURATION_DIRECTORY, CONFIGURATION_FILENAME
-    )
-    with open(file_path, "r") as f:
-        data = json.load(f)
-    # Substitute only the "project" variable with the value of "PROJECT_ID"
-    substituted_data = substitute_env_variable(data, "project", "PROJECT_ID")
-    return substituted_data
-
-
-@pytest.fixture(scope="session")
-def bucket_name() -> str:
-    filenames = [
-        "sample_prompt_template.txt",
-        "sample_prompts.jsonl",
-        "sample_system_instruction.txt",
-    ]
-    # cleanup existing stale resources
-    _clean_resources(STAGING_BUCKET_NAME)
-    # create bucket
-    bucket = STORAGE_CLIENT.bucket(STAGING_BUCKET_NAME)
-    bucket.storage_class = "STANDARD"
-    new_bucket = STORAGE_CLIENT.create_bucket(bucket, location="us")
-    # update JSON to substitute env variables
-    substituted_data = update_json()
-    # convert the JSON data to a byte string
-    json_str = json.dumps(substituted_data, indent=2)
-    json_bytes = json_str.encode("utf-8")
-    # upload substituted JSON file to the bucket
-    blob = bucket.blob(CONFIGURATION_FILENAME)
-    blob.upload_from_string(json_bytes)
-    # upload config files to the bucket
-    transfer_manager.upload_many_from_filenames(
-        new_bucket,
-        filenames,
-        source_directory=os.path.join(
-            os.path.dirname(__file__), CONFIGURATION_DIRECTORY
-        ),
-    )
-    yield new_bucket.name
-    _clean_resources(new_bucket.name)
-
-
-def _main_test(test_func: Callable) -> None:
-    job_resource_name: str = ""
-    timeout = 900  # seconds
-    # wait for the job to complete
-    try:
-        job_resource_name = test_func()
-        start_time = time.time()
-        while (
-            get_job(job_resource_name).state
-            not in [JobState.JOB_STATE_SUCCEEDED, JobState.JOB_STATE_FAILED]
-            and time.time() - start_time < timeout
-        ):
-            time.sleep(10)
-    finally:
-        # delete job
-        get_job(job_resource_name).delete()
-
-
-def test_prompt_optimizer(bucket_name: pytest.fixture()) -> None:
-    _main_test(
-        test_func=lambda: optimize_prompts(
-            PROJECT_ID,
-            LOCATION,
-            f"gs://{bucket_name}",
-            f"gs://{bucket_name}/{CONFIGURATION_FILENAME}",
-        )
-    )
-    assert (
-        STORAGE_CLIENT.get_bucket(bucket_name).list_blobs(prefix=OUTPUT_PATH)
-        is not None
-    )
-
-
-def get_job(job_resource_name: str) -> CustomJob:
-    return aiplatform.CustomJob.get(
-        resource_name=job_resource_name, project=PROJECT_ID, location=LOCATION
-    )
+CLOUD_BUCKET = "gs://python-docs-samples-tests"
+CONFIG_PATH = "ai-platform/prompt_optimization/instructions/sample_configuration.json"
+OUTPUT_PATH = "ai-platform/prompt_optimization/output/"
+
+
+def test_prompt_optimizer() -> None:
+    custom_job_name = prompts_custom_job_example(CLOUD_BUCKET, CONFIG_PATH, OUTPUT_PATH)
+    job = aiplatform.CustomJob.get(
+        resource_name=custom_job_name, project=PROJECT_ID, location="us-central1"
+    )
+
+    storage_client = storage.Client()
+    start_time = time.time()
+    timeout = 1200
+
+    while (
+        job.state not in [JobState.JOB_STATE_SUCCEEDED, JobState.JOB_STATE_FAILED]
+        and time.time() - start_time < timeout
+    ):
+        print(f"Waiting for the CustomJob({job.resource_name}) to be ready!")
+        time.sleep(10)
+    assert (
+        storage_client.get_bucket(CLOUD_BUCKET.split("gs://")[-1]).list_blobs(
+            prefix=OUTPUT_PATH
+        )
+        is not None
+    )
+    print(f"CustomJob({job.resource_name}) is ready. Deleting it now.")
+    job.delete()
+    # delete output blobs
+    blobs = storage_client.get_bucket(CLOUD_BUCKET.split("gs://")[-1]).list_blobs(
+        prefix=OUTPUT_PATH
+    )
+    for blob in blobs:
+        blob.delete()
8 changes: 4 additions & 4 deletions generative_ai/prompts/test_resources/sample_configuration.json
@@ -1,11 +1,11 @@
 {
   "project": "$PROJECT_ID",
-  "system_instruction_path": "gs://prompt_optimizer_bucket/sample_system_instruction.txt",
-  "prompt_template_path": "gs://prompt_optimizer_bucket/sample_prompt_template.txt",
+  "system_instruction_path": "gs://$CLOUD_BUCKET/sample_system_instruction.txt",
+  "prompt_template_path": "gs://$CLOUD_BUCKET/sample_prompt_template.txt",
   "target_model": "gemini-1.5-flash-001",
   "eval_metrics_types": ["safety"],
   "optimization_mode": "instruction",
-  "input_data_path": "gs://prompt_optimizer_bucket/sample_prompts.jsonl",
-  "output_path": "gs://prompt_optimizer_bucket",
+  "input_data_path": "gs://$CLOUD_BUCKET/sample_prompts.jsonl",
+  "output_path": "gs://$CLOUD_BUCKET",
   "eval_metrics_weights": [1]
 }
