Skip to content

Commit 0fe057f

Browse files
authored
docs(discoveryengine): Add Search Tuning Samples (#12788)
* docs(discoveryengine): Add Search Tuning Samples * Add Update Serving config sample * Change search sample to have tuned model as optional parameter
1 parent 3f0b332 commit 0fe057f

File tree

5 files changed

+227
-0
lines changed

5 files changed

+227
-0
lines changed

discoveryengine/search_sample.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ def search_sample(
8282
spell_correction_spec=discoveryengine.SearchRequest.SpellCorrectionSpec(
8383
mode=discoveryengine.SearchRequest.SpellCorrectionSpec.Mode.AUTO
8484
),
85+
# Optional: Use fine-tuned model for this request
86+
# custom_fine_tuning_spec=discoveryengine.CustomFineTuningSpec(
87+
# enable_search_adaptor=True
88+
# ),
8589
)
8690

8791
page_result = client.search(request)
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
# [START genappbuilder_train_custom_model]
17+
18+
from google.api_core.client_options import ClientOptions
19+
from google.api_core.operation import Operation
20+
from google.cloud import discoveryengine
21+
22+
# TODO(developer): Uncomment these variables before running the sample.
23+
# project_id = "YOUR_PROJECT_ID"
24+
# location = "YOUR_LOCATION" # Values: "global"
25+
# data_store_id = "YOUR_DATA_STORE_ID"
26+
# corpus_data_path = "gs://my-bucket/corpus.jsonl"
27+
# query_data_path = "gs://my-bucket/query.jsonl"
28+
# train_data_path = "gs://my-bucket/train.tsv"
29+
# test_data_path = "gs://my-bucket/test.tsv"
30+
31+
32+
def train_custom_model_sample(
33+
project_id: str,
34+
location: str,
35+
data_store_id: str,
36+
corpus_data_path: str,
37+
query_data_path: str,
38+
train_data_path: str,
39+
test_data_path: str,
40+
) -> Operation:
41+
# For more information, refer to:
42+
# https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store
43+
client_options = (
44+
ClientOptions(api_endpoint=f"{location}-discoveryengine.googleapis.com")
45+
if location != "global"
46+
else None
47+
)
48+
# Create a client
49+
client = discoveryengine.SearchTuningServiceClient(client_options=client_options)
50+
51+
# The full resource name of the data store
52+
data_store = f"projects/{project_id}/locations/{location}/collections/default_collection/dataStores/{data_store_id}"
53+
54+
# Make the request
55+
operation = client.train_custom_model(
56+
request=discoveryengine.TrainCustomModelRequest(
57+
gcs_training_input=discoveryengine.TrainCustomModelRequest.GcsTrainingInput(
58+
corpus_data_path=corpus_data_path,
59+
query_data_path=query_data_path,
60+
train_data_path=train_data_path,
61+
test_data_path=test_data_path,
62+
),
63+
data_store=data_store,
64+
model_type="search-tuning",
65+
)
66+
)
67+
68+
# Optional: Wait for training to complete
69+
# print(f"Waiting for operation to complete: {operation.operation.name}")
70+
# response = operation.result()
71+
72+
# After the operation is complete,
73+
# get information from operation metadata
74+
# metadata = discoveryengine.TrainCustomModelMetadata(operation.metadata)
75+
76+
# Handle the response
77+
# print(response)
78+
# print(metadata)
79+
print(operation)
80+
81+
return operation
82+
83+
84+
# [END genappbuilder_train_custom_model]
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
import os
17+
18+
from discoveryengine import train_custom_model_sample
19+
from google.api_core.exceptions import AlreadyExists
20+
21+
project_id = os.environ["GOOGLE_CLOUD_PROJECT"]
22+
location = "global"
23+
data_store_id = "tuning-data-store"
24+
corpus_data_path = "gs://cloud-samples-data/gen-app-builder/search-tuning/corpus.jsonl"
25+
query_data_path = "gs://cloud-samples-data/gen-app-builder/search-tuning/query.jsonl"
26+
train_data_path = "gs://cloud-samples-data/gen-app-builder/search-tuning/training.tsv"
27+
test_data_path = "gs://cloud-samples-data/gen-app-builder/search-tuning/test.tsv"
28+
29+
30+
def test_train_custom_model():
31+
try:
32+
operation = train_custom_model_sample.train_custom_model_sample(
33+
project_id,
34+
location,
35+
data_store_id,
36+
corpus_data_path,
37+
query_data_path,
38+
train_data_path,
39+
test_data_path,
40+
)
41+
assert operation
42+
except AlreadyExists:
43+
# Ignore AlreadyExists; training is already in progress.
44+
pass
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
# [START genappbuilder_update_serving_config]
17+
18+
from google.api_core.client_options import ClientOptions
19+
from google.cloud import discoveryengine_v1alpha as discoveryengine
20+
21+
# TODO(developer): Uncomment these variables before running the sample.
22+
# project_id = "YOUR_PROJECT_ID"
23+
# location = "YOUR_LOCATION" # Values: "global"
24+
# engine_id = "YOUR_DATA_STORE_ID"
25+
26+
27+
def update_serving_config_sample(
28+
project_id: str,
29+
location: str,
30+
engine_id: str,
31+
) -> discoveryengine.ServingConfig:
32+
# For more information, refer to:
33+
# https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store
34+
client_options = (
35+
ClientOptions(api_endpoint=f"{location}-discoveryengine.googleapis.com")
36+
if location != "global"
37+
else None
38+
)
39+
# Create a client
40+
client = discoveryengine.ServingConfigServiceClient(client_options=client_options)
41+
42+
# The full resource name of the serving config
43+
serving_config_name = f"projects/{project_id}/locations/{location}/collections/default_collection/engines/{engine_id}/servingConfigs/default_search"
44+
45+
update_mask = "customFineTuningSpec.enableSearchAdaptor"
46+
47+
serving_config = client.update_serving_config(
48+
request=discoveryengine.UpdateServingConfigRequest(
49+
serving_config=discoveryengine.ServingConfig(
50+
name=serving_config_name,
51+
custom_fine_tuning_spec=discoveryengine.CustomFineTuningSpec(
52+
enable_search_adaptor=True # Switch to `False` to disable tuned model
53+
),
54+
),
55+
update_mask=update_mask,
56+
)
57+
)
58+
59+
# Handle the response
60+
print(serving_config)
61+
62+
return serving_config
63+
64+
65+
# [END genappbuilder_update_serving_config]
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
import os
17+
18+
from discoveryengine import update_serving_config_sample
19+
20+
project_id = os.environ["GOOGLE_CLOUD_PROJECT"]
21+
location = "global"
22+
engine_id = "tuning-sample-app"
23+
24+
25+
def test_update_serving_config():
26+
serving_config = update_serving_config_sample.update_serving_config_sample(
27+
project_id, location, engine_id
28+
)
29+
assert serving_config
30+
assert serving_config.custom_fine_tuning_spec.enable_search_adaptor

0 commit comments

Comments
 (0)