@@ -116,6 +116,20 @@ def get_dataset_mock():
116
116
yield get_dataset_mock
117
117
118
118
119
@pytest.fixture
def get_dataset_template_config_mock():
    """Patch ``DatasetServiceClient.get_dataset`` for template-config tests.

    While the fixture is active, the patched ``get_dataset`` returns a
    ``gca_dataset.Dataset`` built from the module's multimodal test
    constants, with ``_TEST_METADATA_MULTIMODAL_WITH_TEMPLATE_CONFIG`` as
    its metadata.

    Yields:
        The mock that replaces ``DatasetServiceClient.get_dataset``.
    """
    patcher = mock.patch.object(
        dataset_service.DatasetServiceClient, "get_dataset"
    )
    with patcher as patched_get_dataset:
        dataset_with_template = gca_dataset.Dataset(
            display_name=_TEST_DISPLAY_NAME,
            metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_MULTIMODAL,
            name=_TEST_NAME,
            metadata=_TEST_METADATA_MULTIMODAL_WITH_TEMPLATE_CONFIG,
        )
        patched_get_dataset.return_value = dataset_with_template
        yield patched_get_dataset
131
+
132
+
119
133
@pytest .fixture
120
134
def get_dataset_request_column_name_mock ():
121
135
with mock .patch .object (
@@ -756,6 +770,60 @@ def test_assess_tuning_validity_request_column_name(
756
770
timeout = None ,
757
771
)
758
772
773
@pytest.mark.usefixtures("get_dataset_template_config_mock")
def test_assess_tuning_validity_uses_attached_template_config(
    self, assess_data_tuning_validation_mock
):
    """Check the request sent when no ``template_config`` is passed.

    The expected request carries, as its read config, the template config
    found under ``geminiTemplateConfigSource`` / ``geminiTemplateConfig``
    in ``_TEST_METADATA_MULTIMODAL_WITH_TEMPLATE_CONFIG``.
    """
    aiplatform.init(project=_TEST_PROJECT)
    dataset = ummd.MultimodalDataset(dataset_name=_TEST_NAME)

    # No template_config argument is supplied to the call under test.
    dataset.assess_tuning_validity(
        model_name="gemini-1.5-flash-exp",
        dataset_usage="SFT_TRAINING",
    )

    # Short alias for the deeply nested request config type.
    tuning_config_cls = (
        gca_dataset_service.AssessDataRequest.TuningValidationAssessmentConfig
    )
    template_source = _TEST_METADATA_MULTIMODAL_WITH_TEMPLATE_CONFIG[
        "geminiTemplateConfigSource"
    ]
    expected_request = gca_dataset_service.AssessDataRequest(
        name=_TEST_NAME,
        tuning_validation_assessment_config=tuning_config_cls(
            model_name="gemini-1.5-flash-exp",
            dataset_usage=tuning_config_cls.DatasetUsage.SFT_TRAINING,
        ),
        gemini_request_read_config=gca_dataset_service.GeminiRequestReadConfig(
            template_config=template_source["geminiTemplateConfig"]
        ),
    )
    assess_data_tuning_validation_mock.assert_called_once_with(
        request=expected_request,
        timeout=None,
    )
798
+
799
@pytest.mark.usefixtures("get_dataset_request_column_name_mock")
def test_assess_tuning_validity_request_column_name_overridden_by_template_config(
    self, assess_data_tuning_validation_mock
):
    """Check that an explicitly passed ``template_config`` is what is sent.

    The expected request's read config holds the raw template config of
    the ``GeminiTemplateConfig`` passed to ``assess_tuning_validity``.
    NOTE(review): the dataset returned by the fixture presumably carries a
    request column name (inferred from the fixture name only) -- confirm
    against the fixture definition.
    """
    aiplatform.init(project=_TEST_PROJECT)
    dataset = ummd.MultimodalDataset(dataset_name=_TEST_NAME)
    explicit_template = ummd.GeminiTemplateConfig(
        field_mapping={"question": "questionColumn"},
    )

    dataset.assess_tuning_validity(
        model_name="gemini-1.5-flash-exp",
        dataset_usage="SFT_TRAINING",
        template_config=explicit_template,
    )

    # Short alias for the deeply nested request config type.
    tuning_config_cls = (
        gca_dataset_service.AssessDataRequest.TuningValidationAssessmentConfig
    )
    expected_request = gca_dataset_service.AssessDataRequest(
        name=_TEST_NAME,
        tuning_validation_assessment_config=tuning_config_cls(
            model_name="gemini-1.5-flash-exp",
            dataset_usage=tuning_config_cls.DatasetUsage.SFT_TRAINING,
        ),
        gemini_request_read_config=gca_dataset_service.GeminiRequestReadConfig(
            template_config=explicit_template._raw_gemini_template_config
        ),
    )
    assess_data_tuning_validation_mock.assert_called_once_with(
        request=expected_request,
        timeout=None,
    )
826
+
759
827
@pytest .mark .usefixtures ("get_dataset_mock" )
760
828
def test_assess_tuning_validity_invalid_dataset_usage_throws_error (
761
829
self , assess_data_tuning_validation_mock
0 commit comments