Implementation: pyspark ML autologging for pipeline (mlflow#4263)
WeichenXu123 authored Apr 23, 2021
1 parent 4a13db1 commit 768d8f1
Showing 3 changed files with 174 additions and 18 deletions.
76 changes: 60 additions & 16 deletions mlflow/pyspark/ml/__init__.py
@@ -79,11 +79,17 @@ def _get_warning_msg_for_skip_log_model(model):


def _should_log_model(spark_model):
from pyspark.ml.base import Model

# TODO: Handle PipelineModel/CrossValidatorModel/TrainValidationSplitModel
class_name = _get_fully_qualified_class_name(spark_model)
if class_name in _log_model_allowlist:
if class_name == "pyspark.ml.classification.OneVsRestModel":
return _should_log_model(spark_model.models[0])
elif class_name == "pyspark.ml.pipeline.PipelineModel":
return all(
_should_log_model(stage) for stage in spark_model.stages if isinstance(stage, Model)
)
else:
return True
else:
@@ -101,8 +107,23 @@ def _get_estimator_info_tags(estimator):
}


def _get_instance_param_map(instance):
def _get_pipeline_stage_hierarchy(pipeline):
from pyspark.ml import Pipeline

stage_hierarchy = []
pipeline_stages = pipeline.getStages()
for stage in pipeline_stages:
if isinstance(stage, Pipeline):
hierarchy_elem = _get_pipeline_stage_hierarchy(stage)
else:
hierarchy_elem = stage.uid
stage_hierarchy.append(hierarchy_elem)
return {pipeline.uid: stage_hierarchy}


def _get_instance_param_map_recursively(instance, level):
from pyspark.ml.param import Params
from pyspark.ml.pipeline import Pipeline
from pyspark.ml.tuning import CrossValidator, TrainValidationSplit

param_map = {
@@ -111,27 +132,44 @@ def _get_instance_param_map(instance):
if instance.isDefined(param)
}
expanded_param_map = {}
for k, v in param_map.items():
if isinstance(v, Params):
# handle the case where the param value type inherits `pyspark.ml.param.Params`,
# e.g. params like `OneVsRest.classifier`/`CrossValidator.estimator`
expanded_param_map[k] = v.uid
internal_param_map = _get_instance_param_map(v)
for ik, iv in internal_param_map.items():
expanded_param_map[f"{v.uid}.{ik}"] = iv
elif k in ["estimator", "estimatorParamMaps"] and isinstance(
instance, (CrossValidator, TrainValidationSplit)
):

is_pipeline = isinstance(instance, Pipeline)
is_parameter_search_estimator = isinstance(instance, (CrossValidator, TrainValidationSplit))

for param_name, param_value in param_map.items():
if level == 0:
logged_param_name = param_name
else:
logged_param_name = f"{instance.uid}.{param_name}"

if is_pipeline and param_name == "stages":
expanded_param_map[logged_param_name] = _get_pipeline_stage_hierarchy(instance)[
instance.uid
]
for stage in instance.getStages():
stage_param_map = _get_instance_param_map_recursively(stage, level + 1)
expanded_param_map.update(stage_param_map)
elif is_parameter_search_estimator and param_name in ["estimator", "estimatorParamMaps"]:
# skip logging the estimator Param and its nested params because they will be
# logged in the nested runs.
# TODO: Log `estimatorParamMaps` as JSON artifacts.
pass
elif isinstance(param_value, Params):
# handle the case where the param value type inherits `pyspark.ml.param.Params`,
# e.g. params like `OneVsRest.classifier`/`CrossValidator.estimator`
expanded_param_map[logged_param_name] = param_value.uid
internal_param_map = _get_instance_param_map_recursively(param_value, level + 1)
expanded_param_map.update(internal_param_map)
else:
expanded_param_map[k] = v
expanded_param_map[logged_param_name] = param_value

return expanded_param_map


def _get_instance_param_map(instance):
return _get_instance_param_map_recursively(instance, level=0)


def _get_warning_msg_for_fit_call_with_a_list_of_params(estimator):
return (
"Skip pyspark ML autologging when calling "
@@ -208,6 +246,7 @@ def autolog(
MAX_ENTITY_KEY_LENGTH,
)
from pyspark.ml.base import Estimator
from pyspark.ml import Pipeline

global _log_model_allowlist

@@ -218,10 +257,15 @@ def _log_pretraining_metadata(estimator, params):
if params and isinstance(params, dict):
estimator = estimator.copy(params)

param_map = _get_instance_param_map(estimator)
if isinstance(estimator, Pipeline):
pipeline_hierarchy = _get_pipeline_stage_hierarchy(estimator)
try_mlflow_log(
mlflow.log_dict, pipeline_hierarchy, artifact_file="pipeline_hierarchy.json"
)

# Chunk model parameters to avoid hitting the log_batch API limit
for chunk in _chunk_dict(
_get_instance_param_map(estimator), chunk_size=MAX_PARAMS_TAGS_PER_BATCH,
):
for chunk in _chunk_dict(param_map, chunk_size=MAX_PARAMS_TAGS_PER_BATCH,):
truncated = _truncate_dict(chunk, MAX_ENTITY_KEY_LENGTH, MAX_PARAM_VAL_LENGTH)
try_mlflow_log(mlflow.log_params, truncated)

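As a reference for reviewers, here is a minimal sketch (not part of the commit) of what the two new helpers return for a nested pipeline; the local SparkSession and the variable names are assumptions for illustration only:

from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import HashingTF, Tokenizer

from mlflow.pyspark.ml import _get_instance_param_map, _get_pipeline_stage_hierarchy

# pyspark ML stages require an active SparkSession to be constructed
spark = SparkSession.builder.master("local[1]").getOrCreate()

tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashing_tf = HashingTF(inputCol="words", outputCol="features")
lr = LogisticRegression(maxIter=2)
inner = Pipeline(stages=[hashing_tf, lr])
outer = Pipeline(stages=[tokenizer, inner])

# The hierarchy is a dict keyed by the pipeline uid; a nested pipeline appears as a
# nested dict: {outer.uid: [tokenizer.uid, {inner.uid: [hashing_tf.uid, lr.uid]}]}
print(_get_pipeline_stage_hierarchy(outer))

# Stage params are flattened and prefixed with the owning stage's uid
# (e.g. f"{lr.uid}.maxIter" -> 2), while the top-level "stages" entry holds
# the hierarchy list shown above.
print(_get_instance_param_map(outer))
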
3 changes: 3 additions & 0 deletions mlflow/pyspark/ml/log_model_allowlist.txt
@@ -39,3 +39,6 @@ pyspark.ml.feature.UnivariateFeatureSelectorModel

# composite model
pyspark.ml.classification.OneVsRestModel

# pipeline model
pyspark.ml.pipeline.PipelineModel
113 changes: 111 additions & 2 deletions tests/spark_autologging/ml/test_pyspark_ml_autologging.py
@@ -21,11 +21,14 @@
MultilayerPerceptronClassifier,
OneVsRest,
)
from pyspark.ml.feature import HashingTF, Tokenizer
from pyspark.ml import Pipeline
from mlflow.pyspark.ml import (
_should_log_model,
_get_instance_param_map,
_get_warning_msg_for_skip_log_model,
_get_warning_msg_for_fit_call_with_a_list_of_params,
_get_pipeline_stage_hierarchy,
)

pytestmark = pytest.mark.large
@@ -57,6 +60,19 @@ def dataset_multinomial(spark_session):
)


@pytest.fixture(scope="module")
def dataset_text(spark_session):
return spark_session.createDataFrame(
[
(0, "a b c d e spark", 1.0),
(1, "b d", 0.0),
(2, "spark f g h", 1.0),
(3, "hadoop mapreduce", 0.0),
],
["id", "text", "label"],
)


def truncate_param_dict(d):
return _truncate_dict(d, MAX_ENTITY_KEY_LENGTH, MAX_PARAM_VAL_LENGTH)

@@ -202,7 +218,7 @@ def test_fit_with_a_list_of_params(dataset_binomial):
mock_set_tags.assert_not_called()


def test_should_log_model(dataset_binomial, dataset_multinomial):
def test_should_log_model(dataset_binomial, dataset_multinomial, dataset_text):
mlflow.pyspark.ml.autolog(log_models=True)
lor = LogisticRegression()

@@ -213,9 +229,24 @@ def test_should_log_model(dataset_binomial, dataset_multinomial):
ova1_model = ova1.fit(dataset_multinomial)
assert _should_log_model(ova1_model)

tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
lr = LogisticRegression(maxIter=2)
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
pipeline_model = pipeline.fit(dataset_text)
assert _should_log_model(pipeline_model)

nested_pipeline = Pipeline(stages=[tokenizer, Pipeline(stages=[hashingTF, lr])])
nested_pipeline_model = nested_pipeline.fit(dataset_text)
assert _should_log_model(nested_pipeline_model)

with mock.patch(
"mlflow.pyspark.ml._log_model_allowlist",
{"pyspark.ml.regression.LinearRegressionModel", "pyspark.ml.classification.OneVsRestModel"},
{
"pyspark.ml.regression.LinearRegressionModel",
"pyspark.ml.classification.OneVsRestModel",
"pyspark.ml.pipeline.PipelineModel",
},
), mock.patch("mlflow.pyspark.ml._logger.warning") as mock_warning:
lr = LinearRegression()
lr_model = lr.fit(dataset_binomial)
@@ -224,6 +255,8 @@ def test_should_log_model(dataset_binomial, dataset_multinomial):
assert not _should_log_model(lor_model)
mock_warning.called_once_with(_get_warning_msg_for_skip_log_model(lor_model))
assert not _should_log_model(ova1_model)
assert not _should_log_model(pipeline_model)
assert not _should_log_model(nested_pipeline_model)


def test_param_map_captures_wrapped_params(dataset_binomial):
@@ -245,3 +278,79 @@ def test_param_map_captures_wrapped_params(dataset_binomial):
assert run_data.params == truncate_param_dict(
stringify_dict_values(_get_instance_param_map(ova))
)


def test_pipeline(dataset_text):
mlflow.pyspark.ml.autolog()

tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
lr = LogisticRegression(maxIter=2, regParam=0.001)
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
inner_pipeline = Pipeline(stages=[hashingTF, lr])
nested_pipeline = Pipeline(stages=[tokenizer, inner_pipeline])

assert _get_pipeline_stage_hierarchy(pipeline) == {
pipeline.uid: [tokenizer.uid, hashingTF.uid, lr.uid]
}
assert _get_pipeline_stage_hierarchy(nested_pipeline) == {
nested_pipeline.uid: [tokenizer.uid, {inner_pipeline.uid: [hashingTF.uid, lr.uid]}]
}

for estimator in [pipeline, nested_pipeline]:
with mlflow.start_run() as run:
model = estimator.fit(dataset_text)

run_id = run.info.run_id
run_data = get_run_data(run_id)
assert run_data.params == truncate_param_dict(
stringify_dict_values(_get_instance_param_map(estimator))
)
assert run_data.tags == get_expected_class_tags(estimator)
assert MODEL_DIR in run_data.artifacts
loaded_model = load_model_by_run_id(run_id)
assert loaded_model.uid == model.uid
assert run_data.artifacts == ["model", "pipeline_hierarchy.json"]


def test_get_instance_param_map(spark_session): # pylint: disable=unused-argument
lor = LogisticRegression(maxIter=3, standardization=False)
lor_params = _get_instance_param_map(lor)
assert (
lor_params["maxIter"] == 3
and not lor_params["standardization"]
and lor_params["family"] == lor.getOrDefault(lor.family)
)

ova = OneVsRest(classifier=lor, labelCol="abcd")
ova_params = _get_instance_param_map(ova)
assert (
ova_params["classifier"] == lor.uid
and ova_params["labelCol"] == "abcd"
and ova_params[f"{lor.uid}.maxIter"] == 3
and ova_params[f"{lor.uid}.family"] == lor.getOrDefault(lor.family)
)

tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
pipeline = Pipeline(stages=[tokenizer, hashingTF, ova])
inner_pipeline = Pipeline(stages=[hashingTF, ova])
nested_pipeline = Pipeline(stages=[tokenizer, inner_pipeline])

pipeline_params = _get_instance_param_map(pipeline)
nested_pipeline_params = _get_instance_param_map(nested_pipeline)

assert pipeline_params["stages"] == [tokenizer.uid, hashingTF.uid, ova.uid]
assert nested_pipeline_params["stages"] == [
tokenizer.uid,
{inner_pipeline.uid: [hashingTF.uid, ova.uid]},
]

for params_to_test in [pipeline_params, nested_pipeline_params]:
assert (
params_to_test[f"{tokenizer.uid}.inputCol"] == "text"
and params_to_test[f"{tokenizer.uid}.outputCol"] == "words"
)
assert params_to_test[f"{hashingTF.uid}.outputCol"] == "features"
assert params_to_test[f"{ova.uid}.classifier"] == lor.uid
assert params_to_test[f"{lor.uid}.maxIter"] == 3
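
Putting the pieces together, the end-to-end behavior these tests exercise can be sketched roughly as follows (illustrative only, not part of the commit; the local SparkSession and the tiny dataset are assumptions mirroring the `dataset_text` fixture):

import mlflow
import mlflow.pyspark.ml
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import HashingTF, Tokenizer

spark = SparkSession.builder.master("local[1]").getOrCreate()
train_df = spark.createDataFrame(
    [
        (0, "a b c d e spark", 1.0),
        (1, "b d", 0.0),
        (2, "spark f g h", 1.0),
        (3, "hadoop mapreduce", 0.0),
    ],
    ["id", "text", "label"],
)

mlflow.pyspark.ml.autolog()

pipeline = Pipeline(
    stages=[
        Tokenizer(inputCol="text", outputCol="words"),
        HashingTF(inputCol="words", outputCol="features"),
        LogisticRegression(maxIter=2),
    ]
)

with mlflow.start_run():
    pipeline.fit(train_df)

# The run should now contain the flattened pipeline params, the estimator class
# tags, a pipeline_hierarchy.json artifact, and the fitted PipelineModel (every
# stage model is in the default allowlist).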
