diff --git a/tests/test_basic.py b/tests/test_basic.py index d71d9e1..c21d5c1 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -258,3 +258,28 @@ def test_update_response_poly_categorical(): ad.update_response_poly_categorical(predictor_name='X6', betas={'Red': -2000, 'Blue': -1700}) assert ad.predictor_matrix.loc[1, 'X6'] == 'Red' assert ad.response_vector[1] < -1900 + +def test_catg_realistic(): + """Test function 'update_predictor_catg_realistic' + Test logic: + length is correct. + New values are of string format. + (Need to be updated with more sophisticated logic) + """ + ## Initialize and use the function to update + ad = AnalyticsDataframe(1000, 3, ["xx1", "xx2", "xx3"], "yy") + ad.update_predictor_catg_realistic("xx1", "name") + ad.update_predictor_catg_realistic("xx2", "address") + pred_matrix = ad.predictor_matrix + + ## Test if the length is correct + assert len(pred_matrix["xx1"]) == 1000 + assert len(pred_matrix["xx2"]) == 1000 + ## Test if the type is string + assert isinstance(pred_matrix["xx1"][0], str) + assert isinstance(pred_matrix["xx2"][0], str) + + ## Test its error cases + # pred_exists error: predictor name doesn't exist + with pytest.raises(KeyError): + ad.update_predictor_catg_realistic("random_name", "name") \ No newline at end of file