@@ -360,7 +360,7 @@ def test_write_session_mismatched_id(example_database: Database, example_subject
360360 (["an_existing_target_id" ], [0 ], pytest .raises (ValueError , match = "provides no transforms" )),
361361 ]
362362)
363- def test_write_session_with_invalid_fit_approval (
363+ def test_write_session_with_invalid_fit_results (
364364 example_database : Database ,
365365 example_subject : Subject ,
366366 target_ids : List [str ],
@@ -384,6 +384,13 @@ def test_write_session_with_invalid_fit_approval(
384384 with expectation :
385385 example_database .write_session (example_subject , session )
386386
387+ def test_session_arrays_read_correctly (example_session :Session ):
388+ """Verify that session data that is supposed to be array type is actually array type after reading from json"""
389+ assert isinstance (example_session .array_transform .matrix , np .ndarray )
390+ for _ , (_ , array_transforms ) in example_session .virtual_fit_results .items ():
391+ for array_transform in array_transforms :
392+ assert isinstance (array_transform .matrix , np .ndarray )
393+
387394@pytest .mark .parametrize ("compact_representation" , [True , False ])
388395def test_serialize_deserialize_session (example_session : Session , compact_representation :bool ):
389396 reconstructed_session = example_session .from_json (example_session .to_json (compact_representation ))
0 commit comments