Skip to content

Commit

Permalink
Fix Spark model logging on passthrough-enabled clusters (mlflow#4549)
Browse files Browse the repository at this point in the history
* Fix spark model logging on passthrough-enabled clusters

* No need for lint comment

* Catch only py4j exceptions
  • Loading branch information
smurching authored Jul 13, 2021
1 parent 35f1b4b commit c991b99
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mlflow/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def log_model(
model = pipeline.fit(training)
mlflow.spark.log_model(model, "spark-model")
"""
from py4j.protocol import Py4JJavaError
from py4j.protocol import Py4JError

_validate_model(spark_model)
from pyspark.ml import PipelineModel
Expand Down Expand Up @@ -208,7 +208,7 @@ def log_model(
# to persist the model
try:
spark_model.save(posixpath.join(model_dir, _SPARK_MODEL_PATH_SUB))
except Py4JJavaError:
except Py4JError:
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.spark,
Expand Down

0 comments on commit c991b99

Please sign in to comment.