Skip to content

Commit 8106346

Browse files
cleop-google authored and copybara-github committed
chore: fix resolving between passed and attached read configs in multimodal datasets
PiperOrigin-RevId: 758655464
1 parent 8c0bf19 commit 8106346

File tree

2 files changed

+82
-20
lines changed

2 files changed

+82
-20
lines changed

google/cloud/aiplatform/preview/datasets.py

+14 -20
Original file line number | Diff line number | Diff line change
@@ -1511,30 +1511,24 @@ def _build_assess_data_request(
15111511
)
15121512

15131513
def _build_gemini_request_read_config(
1514-
self, template_config: Optional[GeminiTemplateConfig] = None
1515-
):
1516-
if self.request_column_name is not None:
1514+
self, provided_template_config: Optional[GeminiTemplateConfig] = None
1515+
) -> gca_dataset_service.GeminiRequestReadConfig:
1516+
"""Returns the provided template config wrapped in a read config if it
1517+
is not None, otherwise returns the read config attached to the
1518+
dataset."""
1519+
if provided_template_config is not None:
15171520
return gca_dataset_service.GeminiRequestReadConfig(
1518-
assembled_request_column_name=self.request_column_name
1521+
template_config=provided_template_config._raw_gemini_template_config
15191522
)
1520-
else:
1521-
template_config_to_use = self._resolve_template_config(template_config)
1523+
elif self.template_config is not None:
15221524
return gca_dataset_service.GeminiRequestReadConfig(
1523-
template_config=template_config_to_use._raw_gemini_template_config
1525+
template_config=self.template_config._raw_gemini_template_config
1526+
)
1527+
elif self.request_column_name is not None:
1528+
return gca_dataset_service.GeminiRequestReadConfig(
1529+
assembled_request_column_name=self.request_column_name
15241530
)
1525-
1526-
def _resolve_template_config(
1527-
self,
1528-
template_config: Optional[GeminiTemplateConfig] = None,
1529-
) -> GeminiTemplateConfig:
1530-
"""Returns the passed template config if it is not None, otherwise
1531-
returns the template config attached to the dataset.
1532-
"""
1533-
if template_config is not None:
1534-
return template_config
1535-
elif self.template_config is not None:
1536-
return self.template_config
15371531
else:
15381532
raise ValueError(
1539-
"No template config was passed or attached to the dataset."
1533+
"No template config was provided and no read config is attached to the dataset."
15401534
)

tests/unit/aiplatform/test_multimodal_datasets.py

+68
Original file line number | Diff line number | Diff line change
@@ -116,6 +116,20 @@ def get_dataset_mock():
116116
yield get_dataset_mock
117117

118118

119+
@pytest.fixture
120+
def get_dataset_template_config_mock():
121+
with mock.patch.object(
122+
dataset_service.DatasetServiceClient, "get_dataset"
123+
) as get_dataset_mock:
124+
get_dataset_mock.return_value = gca_dataset.Dataset(
125+
display_name=_TEST_DISPLAY_NAME,
126+
metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_MULTIMODAL,
127+
name=_TEST_NAME,
128+
metadata=_TEST_METADATA_MULTIMODAL_WITH_TEMPLATE_CONFIG,
129+
)
130+
yield get_dataset_mock
131+
132+
119133
@pytest.fixture
120134
def get_dataset_request_column_name_mock():
121135
with mock.patch.object(
@@ -756,6 +770,60 @@ def test_assess_tuning_validity_request_column_name(
756770
timeout=None,
757771
)
758772

773+
@pytest.mark.usefixtures("get_dataset_template_config_mock")
774+
def test_assess_tuning_validity_uses_attached_template_config(
775+
self, assess_data_tuning_validation_mock
776+
):
777+
aiplatform.init(project=_TEST_PROJECT)
778+
dataset = ummd.MultimodalDataset(dataset_name=_TEST_NAME)
779+
dataset.assess_tuning_validity(
780+
model_name="gemini-1.5-flash-exp",
781+
dataset_usage="SFT_TRAINING",
782+
)
783+
assess_data_tuning_validation_mock.assert_called_once_with(
784+
request=gca_dataset_service.AssessDataRequest(
785+
name=_TEST_NAME,
786+
tuning_validation_assessment_config=gca_dataset_service.AssessDataRequest.TuningValidationAssessmentConfig(
787+
model_name="gemini-1.5-flash-exp",
788+
dataset_usage=gca_dataset_service.AssessDataRequest.TuningValidationAssessmentConfig.DatasetUsage.SFT_TRAINING,
789+
),
790+
gemini_request_read_config=gca_dataset_service.GeminiRequestReadConfig(
791+
template_config=_TEST_METADATA_MULTIMODAL_WITH_TEMPLATE_CONFIG[
792+
"geminiTemplateConfigSource"
793+
]["geminiTemplateConfig"]
794+
),
795+
),
796+
timeout=None,
797+
)
798+
799+
@pytest.mark.usefixtures("get_dataset_request_column_name_mock")
800+
def test_assess_tuning_validity_request_column_name_overridden_by_template_config(
801+
self, assess_data_tuning_validation_mock
802+
):
803+
aiplatform.init(project=_TEST_PROJECT)
804+
dataset = ummd.MultimodalDataset(dataset_name=_TEST_NAME)
805+
template_config = ummd.GeminiTemplateConfig(
806+
field_mapping={"question": "questionColumn"},
807+
)
808+
dataset.assess_tuning_validity(
809+
model_name="gemini-1.5-flash-exp",
810+
dataset_usage="SFT_TRAINING",
811+
template_config=template_config,
812+
)
813+
assess_data_tuning_validation_mock.assert_called_once_with(
814+
request=gca_dataset_service.AssessDataRequest(
815+
name=_TEST_NAME,
816+
tuning_validation_assessment_config=gca_dataset_service.AssessDataRequest.TuningValidationAssessmentConfig(
817+
model_name="gemini-1.5-flash-exp",
818+
dataset_usage=gca_dataset_service.AssessDataRequest.TuningValidationAssessmentConfig.DatasetUsage.SFT_TRAINING,
819+
),
820+
gemini_request_read_config=gca_dataset_service.GeminiRequestReadConfig(
821+
template_config=template_config._raw_gemini_template_config
822+
),
823+
),
824+
timeout=None,
825+
)
826+
759827
@pytest.mark.usefixtures("get_dataset_mock")
760828
def test_assess_tuning_validity_invalid_dataset_usage_throws_error(
761829
self, assess_data_tuning_validation_mock

0 commit comments

Comments
 (0)