Support Calling Java Function in Python Executor and ModelBroadcast in Python #2284
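
This PR does two related things. First, it reworks callJavaFunc/callBigDlFunc so that the py4j gateway is resolved per process instead of always going through the driver's SparkContext, which lets BigDL's Python API call into the JVM from executor processes as well as from the driver. Second, it adds a ModelBroadcast that ships a model to executors using BigDL's native saveModel/loadModel instead of Python pickling. A minimal end-to-end sketch, pieced together from the diffs and the new test at the bottom of this page (it assumes a live SparkContext sc):

    import numpy as np
    from bigdl.nn.layer import Linear
    from bigdl.util.common import init_engine, init_executor_gateway
    from bigdl.models.utils.model_broadcast import broadcastModel

    init_engine()              # initialize BigDL on the driver
    init_executor_gateway(sc)  # new in this PR: connect executor Python processes to the JVM

    model = Linear(3, 2)
    broadcasted = broadcastModel(sc, model)
    # broadcasted.value lazily reloads the model inside each executor's Python process
    preds = sc.parallelize([np.random.rand(3)], 1) \
        .map(lambda x: broadcasted.value.forward(x)).collect()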

Merged 12 commits on Feb 9, 2018.
Empty file.
72 changes: 72 additions & 0 deletions pyspark/bigdl/models/utils/model_broadcast.py
@@ -0,0 +1,72 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import sys
import gc
from tempfile import NamedTemporaryFile

from pyspark.cloudpickle import print_exec
from pyspark.broadcast import Broadcast
from pyspark.broadcast import _from_id
from bigdl.nn.layer import Model

def _from_id_and_type(bid, bigdl_type):
result = _from_id(bid)
return ModelBroadcast(path=result._path, bigdl_type=bigdl_type)

def broadcastModel(sc, layer):
return ModelBroadcast(sc, layer, sc._pickled_broadcast_vars)

class ModelBroadcast(Broadcast):

def __init__(self, sc=None, layer=None, pickle_registry=None, path=None, bigdl_type="float"):
"""
Should not be called directly by users -- use L{SparkContext.broadcast()}
instead.
"""
if layer is not None:
self.bigdl_type = layer.bigdl_type
else:
self.bigdl_type = bigdl_type
super(ModelBroadcast, self).__init__(sc, layer, pickle_registry, path)

def dump(self, value, f):
try:
value.saveModel(f.name, over_write=True)
except Exception as e:
msg = "Could not serialize broadcast: %s" % e.__class__.__name__
print_exec(sys.stderr)
raise ValueError(msg)
f.close()
return f.name

def _load(self, path):
return Model.loadModel(path, bigdl_type=self.bigdl_type)

@property
def value(self):
""" Return the broadcasted value
"""
if not hasattr(self, "_value") and self._path is not None:
self._value = self._load(self._path)
return self._value

def __reduce__(self):
if self._jbroadcast is None:
raise Exception("Broadcast can only be serialized in driver")
self._pickle_registry.add(self)
return _from_id_and_type, (self._jbroadcast.id(), self.bigdl_type)
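
ModelBroadcast works because pyspark's Broadcast persists its value through dump() into a temporary file and lazily rehydrates it through _load(); overriding those two hooks swaps Python pickling for BigDL's native model serialization. A sketch of the underlying round trip, using only the calls that appear in dump() and _load() above (the /tmp path is illustrative):

    from bigdl.nn.layer import Linear, Model

    layer = Linear(3, 2)
    # what dump() does, into a NamedTemporaryFile
    layer.saveModel("/tmp/linear.bigdl", over_write=True)
    # what _load() does on first access to .value
    restored = Model.loadModel("/tmp/linear.bigdl", bigdl_type="float")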
18 changes: 9 additions & 9 deletions pyspark/bigdl/nn/layer.py
@@ -54,10 +54,10 @@ def element(self):
return Layer.of(self.value.element())

def remove_pre_edges(self):
-    callJavaFunc(get_spark_context(), self.value.removePreEdges)
+    callJavaFunc(self.value.removePreEdges)

def remove_next_edges(self):
-    callJavaFunc(get_spark_context(), self.value.removeNextEdges)
+    callJavaFunc(self.value.removeNextEdges)



@@ -130,14 +130,14 @@ def set_name(self, name):
Give this model a name. A generated name consisting of the class name
and a UUID will be used if the user doesn't set one.
"""
-    callJavaFunc(get_spark_context(), self.value.setName, name)
+    callJavaFunc(self.value.setName, name)
return self

def name(self):
"""
Name of this layer
"""
-    return callJavaFunc(get_spark_context(), self.value.getName)
+    return callJavaFunc(self.value.getName)

def set_seed(self, seed=123):
"""
@@ -230,7 +230,7 @@ def zero_grad_parameters(self):
If the module has parameters, this will zero the accumulation of the gradients with respect
to these parameters. Otherwise, it does nothing.
"""
-    callJavaFunc(get_spark_context(), self.value.zeroGradParameters)
+    callJavaFunc(self.value.zeroGradParameters)

def update_parameters(self, learning_rate):
"""
@@ -245,7 +245,7 @@ def reset(self):
"""
Initialize the model weights.
"""
-    callJavaFunc(get_spark_context(), self.value.reset)
+    callJavaFunc(self.value.reset)
return self

def parameters(self):
@@ -528,9 +528,9 @@ def training(self, is_training=True):
Set this layer to training mode, or to prediction mode if is_training=False
'''
if is_training:
-        callJavaFunc(get_spark_context(), self.value.training)
+        callJavaFunc(self.value.training)
else:
-        callJavaFunc(get_spark_context(), self.value.evaluate)
+        callJavaFunc(self.value.evaluate)
return self

def is_training(self):
@@ -546,7 +546,7 @@ def is_training(self):
>>> layer.is_training()
True
'''
-    return callJavaFunc(get_spark_context(), self.value.isTraining)
+    return callJavaFunc(self.value.isTraining)

def quantize(self):
'''
2 changes: 1 addition & 1 deletion pyspark/bigdl/optim/optimizer.py
@@ -628,7 +628,7 @@ def optimize(self):
"""
Do an optimization.
"""
-    jmodel = callJavaFunc(get_spark_context(), self.value.optimize)
+    jmodel = callJavaFunc(self.value.optimize)
from bigdl.nn.layer import Layer
return Layer.of(jmodel)

76 changes: 51 additions & 25 deletions pyspark/bigdl/util/common.py
@@ -16,16 +16,17 @@

import os
import sys
import glob
from py4j.protocol import Py4JJavaError
from py4j.java_gateway import JavaObject
from py4j.java_collections import ListConverter, JavaArray, JavaList, JavaMap, MapConverter
+ from py4j.java_gateway import JavaGateway, GatewayClient

from pyspark import RDD, SparkContext
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
from pyspark.sql import DataFrame, SQLContext
- from pyspark.mllib.common import callJavaFunc
from pyspark import SparkConf
+ from pyspark.files import SparkFiles
import numpy as np
import threading
import tempfile
@@ -46,13 +47,18 @@ class SingletonMixin(object):

@classmethod
def instance(cls,
bigdl_type="float"):
bigdl_type, *args):
if not cls._instance:
with cls._lock:
if not cls._instance:
-                    cls._instance = cls(bigdl_type)
+                    cls._instance = cls(bigdl_type, *args)
return cls._instance

class GatewayWrapper(SingletonMixin):

def __init__(self, bigdl_type, port=25333):
self.value = JavaGateway(GatewayClient(port=port), auto_convert=True)
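
GatewayWrapper caches one py4j connection per Python process (via SingletonMixin above), so executor-side code can reach the JVM without a SparkContext. A standalone sketch of what it wraps; the port here is a placeholder (25333 is py4j's default), while in practice it comes from the gateway_port file read by _get_port() further down:

    from py4j.java_gateway import JavaGateway, GatewayClient

    gateway = JavaGateway(GatewayClient(port=25333), auto_convert=True)
    # gateway.jvm proxies JVM classes, e.g.
    # gateway.jvm.com.intel.analytics.bigdl.python.api.PythonBigDLKeras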


class JavaCreator(SingletonMixin):
__creator_class=["com.intel.analytics.bigdl.python.api.PythonBigDLKeras"]
@@ -74,11 +80,10 @@ def set_creator_class(cls, cclass):
JavaCreator.__creator_class = cclass
JavaCreator._instance = None

-    def __init__(self, bigdl_type):
-        sc = get_spark_context()
+    def __init__(self, bigdl_type, gateway):
self.value = []
for creator_class in JavaCreator.get_creator_class():
-            jclass = getattr(sc._jvm, creator_class)
+            jclass = getattr(gateway.jvm, creator_class)
if bigdl_type == "float":
self.value.append(getattr(jclass, "ofFloat")())
elif bigdl_type == "double":
@@ -437,6 +442,9 @@ def uniform(self, a, b, size):
def init_engine(bigdl_type="float"):
callBigDlFunc(bigdl_type, "initEngine")

def init_executor_gateway(sc, bigdl_type="float"):
callBigDlFunc(bigdl_type, "initExecutorGateway", sc, sc._gateway._gateway_client.port)

Contributor: When is this method called?


def redire_spark_logs(bigdl_type="float", log_path=os.getcwd()+"/bigdl.log"):
"""
@@ -556,16 +564,33 @@ def get_spark_sql_context(sc):
else:
return SQLContext(sc) # Compatible with Spark1.5.1

def _get_port():
root_dir = SparkFiles.getRootDirectory()
path = os.path.join(root_dir, "gateway_port")
f = open(path)
Contributor: What if this fails?
Contributor (author): I'll submit another PR and make it report a meaningful error message.

port = int(f.readline())
return port
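
As the reply above promises, a follow-up would surface a clearer error when the port file is missing or unreadable. A hypothetical sketch of that hardening (not part of this PR):

    def _get_port():
        root_dir = SparkFiles.getRootDirectory()
        path = os.path.join(root_dir, "gateway_port")
        try:
            with open(path) as f:
                return int(f.readline())
        except IOError:
            raise RuntimeError("Cannot read executor gateway port file at %s; "
                               "was init_executor_gateway(sc) called on the driver?" % path)
        except ValueError:
            raise RuntimeError("Malformed gateway port file at %s" % path)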

def _get_gateway():
if SparkFiles._is_running_on_worker:
gateway_port = _get_port()
gateway = GatewayWrapper.instance(None, gateway_port).value
else:
sc = get_spark_context()
gateway = sc._gateway
return gateway


def callBigDlFunc(bigdl_type, name, *args):
""" Call API in PythonBigDL """
-    sc = get_spark_context()
+    gateway = _get_gateway()
error = Exception("Cannot find function: %s" % name)
-    for jinvoker in JavaCreator.instance(bigdl_type=bigdl_type).value:
+    for jinvoker in JavaCreator.instance(bigdl_type, gateway).value:
# hasattr(jinvoker, name) always returns true here,
# so you need to invoke the method to check whether it exists or not
try:
api = getattr(jinvoker, name)
-            result = callJavaFunc(sc, api, *args)
+            result = callJavaFunc(api, *args)
except Exception as e:
error = e
if "does not exist" not in str(e):
@@ -575,7 +600,7 @@ def callBigDlFunc(bigdl_type, name, *args):
raise error
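
The comment about hasattr reflects a py4j quirk: attribute lookup on a Java proxy object succeeds for any method name and returns a callable stub, so existence can only be verified by invoking it. A small sketch of the behavior the loop above works around (the method name is made up):

    api = getattr(jinvoker, "noSuchMethod")  # "succeeds": returns a py4j proxy
    api()  # only now does py4j raise an error whose message contains "does not exist"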


- def _java2py(sc, r, encoding="bytes"):
+ def _java2py(gateway, r, encoding="bytes"):
if isinstance(r, JavaObject):
clsName = r.getClass().getSimpleName()
# convert RDD into JavaRDD
@@ -584,20 +609,20 @@ def _java2py(sc, r, encoding="bytes"):
clsName = 'JavaRDD'

if clsName == 'JavaRDD':
-        jrdd = sc._jvm.SerDe.javaToPython(r)
-        return RDD(jrdd, sc)
+        jrdd = gateway.jvm.SerDe.javaToPython(r)
+        return RDD(jrdd, get_spark_context())

if clsName == 'DataFrame':
-        return DataFrame(r, get_spark_sql_context(sc))
+        return DataFrame(r, get_spark_sql_context(get_spark_context()))

if clsName == 'Dataset':
-        return DataFrame(r, get_spark_sql_context(sc))
+        return DataFrame(r, get_spark_sql_context(get_spark_context()))

if clsName in _picklable_classes:
-        r = sc._jvm.org.apache.spark.bigdl.api.python.BigDLSerDe.dumps(r)
+        r = gateway.jvm.org.apache.spark.bigdl.api.python.BigDLSerDe.dumps(r)
elif isinstance(r, (JavaArray, JavaList, JavaMap)):
try:
-            r = sc._jvm.org.apache.spark.bigdl.api.python.BigDLSerDe.dumps(
+            r = gateway.jvm.org.apache.spark.bigdl.api.python.BigDLSerDe.dumps(
r)
except Py4JJavaError:
pass  # not picklable
@@ -607,11 +632,12 @@ def _java2py(sc, r, encoding="bytes"):
return r


- def callJavaFunc(sc, func, *args):
+ def callJavaFunc(func, *args):
""" Call Java Function """
-    args = [_py2java(sc, a) for a in args]
+    gateway = _get_gateway()
+    args = [_py2java(gateway, a) for a in args]
result = func(*args)
-    return _java2py(sc, result)
+    return _java2py(gateway, result)


def _to_java_object_rdd(rdd):
@@ -627,7 +653,7 @@ def _to_java_object_rdd(rdd):
rdd._jrdd, True)


- def _py2java(sc, obj):
+ def _py2java(gateway, obj):
""" Convert Python object into Java """
if isinstance(obj, RDD):
obj = _to_java_object_rdd(obj)
@@ -636,13 +662,13 @@ def _py2java(sc, obj):
elif isinstance(obj, SparkContext):
obj = obj._jsc
elif isinstance(obj, (list, tuple)):
-        obj = ListConverter().convert([_py2java(sc, x) for x in obj],
-                                      sc._gateway._gateway_client)
+        obj = ListConverter().convert([_py2java(gateway, x) for x in obj],
+                                      gateway._gateway_client)
elif isinstance(obj, dict):
result = {}
for (key, value) in obj.items():
-            result[key] = _py2java(sc, value)
-        obj = MapConverter().convert(result, sc._gateway._gateway_client)
+            result[key] = _py2java(gateway, value)
+        obj = MapConverter().convert(result, gateway._gateway_client)
elif isinstance(obj, JavaValue):
obj = obj.value
elif isinstance(obj, JavaObject):
@@ -651,7 +677,7 @@ def _py2java(sc, obj):
pass
else:
data = bytearray(PickleSerializer().dumps(obj))
-        obj = sc._jvm.org.apache.spark.bigdl.api.python.BigDLSerDe.loads(data)
+        obj = gateway.jvm.org.apache.spark.bigdl.api.python.BigDLSerDe.loads(data)
return obj


15 changes: 14 additions & 1 deletion pyspark/test/bigdl/test_simple_integration.py
@@ -28,6 +28,7 @@
from numpy.testing import assert_allclose, assert_array_equal
from bigdl.util.engine import compare_version
from bigdl.transform.vision.image import *
+ from bigdl.models.utils.model_broadcast import broadcastModel
np.random.seed(1337) # for reproducibility


@@ -533,7 +534,7 @@ def test_save_jtensor_dict(self):
tensors["tensor1"] = JTensor.from_ndarray(np.random.rand(3, 2))
tensors["tensor2"] = JTensor.from_ndarray(np.random.rand(3, 2))
# in old impl, this will throw an exception
-        _py2java(self.sc, tensors)
+        _py2java(self.sc._gateway, tensors)
Contributor: Add a unit test for the newly added ModelBroadcast?
Contributor (author): Added; see test_model_broadcast below.


def test_compare_version(self):
assert compare_version("2.1.1", "2.2.0") == -1
@@ -601,5 +602,17 @@ def test_local_predict_multiple_input(self):
JTensor.from_ndarray(np.ones([4, 3]))])
assert result4.shape == (4,)

def test_model_broadcast(self):

init_executor_gateway(self.sc)
model = Linear(3, 2)
broadcasted = broadcastModel(self.sc, model)
input_data = np.random.rand(3)
output = self.sc.parallelize([input_data], 1)\
.map(lambda x: broadcasted.value.forward(x)).first()
expected = model.forward(input_data)

assert_allclose(output, expected)

if __name__ == "__main__":
pytest.main([__file__])