diff --git a/mlflow/spark.py b/mlflow/spark.py index 55cd85215081e..c8912ff3ed265 100644 --- a/mlflow/spark.py +++ b/mlflow/spark.py @@ -176,7 +176,7 @@ def log_model( model = pipeline.fit(training) mlflow.spark.log_model(model, "spark-model") """ - from py4j.protocol import Py4JJavaError + from py4j.protocol import Py4JError _validate_model(spark_model) from pyspark.ml import PipelineModel @@ -208,7 +208,7 @@ def log_model( # to persist the model try: spark_model.save(posixpath.join(model_dir, _SPARK_MODEL_PATH_SUB)) - except Py4JJavaError: + except Py4JError: return Model.log( artifact_path=artifact_path, flavor=mlflow.spark,