Skip to content

Commit f0a4a00

Browse files
sararobcopybara-github
authored andcommitted
fix: GenAI SDK (prompts) - Fix bug where passing encryption_spec to prompts.create raised an error
PiperOrigin-RevId: 821762191
1 parent 8b9ed04 commit f0a4a00

File tree

3 files changed

+36
-18
lines changed

3 files changed

+36
-18
lines changed

tests/unit/vertexai/genai/replays/test_create_prompt.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
from tests.unit.vertexai.genai.replays import pytest_helper
1818
from vertexai._genai import types
1919
from google.genai import types as genai_types
20-
2120
import pytest
2221

22+
2323
TEST_PROMPT_DATASET_ID = "8005484238453342208"
2424
TEST_VARIABLES = [
2525
{"name": genai_types.Part(text="Alice")},
@@ -289,6 +289,36 @@ def test_create_with_file_data(client):
289289
assert contents[0] == prompt_resource.prompt_data.contents[0]
290290

291291

292+
def test_create_with_encryption_spec(client):
293+
encryption_spec = genai_types.EncryptionSpec(
294+
kms_key_name="projects/vertex-sdk-dev/locations/us-central1/keyRings/my-key-ring/cryptoKeys/my-key",
295+
)
296+
config = types.CreatePromptConfig(
297+
prompt_display_name="my_prompt_with_encryption_spec",
298+
encryption_spec=encryption_spec,
299+
)
300+
prompt_resource = client.prompts.create(
301+
prompt=TEST_PROMPT,
302+
config=config,
303+
)
304+
assert isinstance(prompt_resource, types.Prompt)
305+
assert isinstance(prompt_resource.dataset, types.Dataset)
306+
307+
# Create a version on a prompt with an encryption spec.
308+
new_prompt = TEST_PROMPT.model_copy(deep=True)
309+
new_prompt.prompt_data.contents[0].parts[0].text = "Is this Alice?"
310+
prompt_version_resource = client.prompts.create_version(
311+
prompt_id=prompt_resource.prompt_id,
312+
prompt=new_prompt,
313+
config=types.CreatePromptVersionConfig(
314+
version_display_name="my_version_existing_dataset",
315+
),
316+
)
317+
assert isinstance(prompt_version_resource, types.Prompt)
318+
assert isinstance(prompt_version_resource.dataset, types.Dataset)
319+
assert isinstance(prompt_version_resource.dataset_version, types.DatasetVersion)
320+
321+
292322
pytestmark = pytest_helper.setup(
293323
file=__file__,
294324
globals_for_file=globals(),

vertexai/_genai/_prompt_management_utils.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#
1515
"""Utility functions for prompt management."""
1616

17-
from typing import Optional, Union
17+
from typing import Optional
1818

1919
from google.genai import types as genai_types
2020

@@ -123,20 +123,8 @@ def _create_prompt_from_dataset_metadata(
123123

124124
def _raise_for_invalid_prompt(
125125
prompt: types.Prompt,
126-
config: Optional[
127-
Union[types.CreatePromptConfig, types.CreatePromptVersionConfig]
128-
] = None,
129126
) -> None:
130127

131-
if (
132-
isinstance(config, types.CreatePromptConfig)
133-
and config.encryption_spec
134-
and config.prompt_id
135-
):
136-
raise ValueError(
137-
"Encryption spec can only be used for creating new prompts, not for creating new prompt versions."
138-
)
139-
140128
if not prompt.prompt_data:
141129
raise ValueError("Prompt data must be provided.")
142130
if not prompt.prompt_data.contents:

vertexai/_genai/prompts.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -934,7 +934,7 @@ def create(
934934
elif not config:
935935
config = types.CreatePromptConfig()
936936

937-
_prompt_management_utils._raise_for_invalid_prompt(prompt, config)
937+
_prompt_management_utils._raise_for_invalid_prompt(prompt)
938938

939939
prompt_metadata = _prompt_management_utils._create_dataset_metadata_from_prompt(
940940
prompt,
@@ -1001,7 +1001,7 @@ def create_version(
10011001
elif not config:
10021002
config = types.CreatePromptVersionConfig()
10031003

1004-
_prompt_management_utils._raise_for_invalid_prompt(prompt, config)
1004+
_prompt_management_utils._raise_for_invalid_prompt(prompt)
10051005

10061006
if config and config.version_display_name:
10071007
version_name = config.version_display_name
@@ -2093,7 +2093,7 @@ async def create(
20932093
elif not config:
20942094
config = types.CreatePromptConfig()
20952095

2096-
_prompt_management_utils._raise_for_invalid_prompt(prompt, config)
2096+
_prompt_management_utils._raise_for_invalid_prompt(prompt)
20972097

20982098
prompt_metadata = _prompt_management_utils._create_dataset_metadata_from_prompt(
20992099
prompt,
@@ -2160,7 +2160,7 @@ async def create_version(
21602160
elif not config:
21612161
config = types.CreatePromptVersionConfig()
21622162

2163-
_prompt_management_utils._raise_for_invalid_prompt(prompt, config)
2163+
_prompt_management_utils._raise_for_invalid_prompt(prompt)
21642164

21652165
if config and config.version_display_name:
21662166
version_name = config.version_display_name

0 commit comments

Comments
 (0)