-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support Calling Java Function in Python Executor and ModelBroadcast in Python #2284
Changes from all commits
bb38cc2
c81a36e
b141b53
1af41ae
d851be4
cf3e1c2
9ef1545
b142440
aad927e
e1c4ffe
b9c3e5a
e6658a2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# | ||
# Copyright 2016 The BigDL Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import os | ||
import sys | ||
import gc | ||
from tempfile import NamedTemporaryFile | ||
|
||
from pyspark.cloudpickle import print_exec | ||
from pyspark.broadcast import Broadcast | ||
from pyspark.broadcast import _from_id | ||
from bigdl.nn.layer import Model | ||
|
||
def _from_id_and_type(bid, bigdl_type):
    """Rebuild a ModelBroadcast on the deserializing side.

    Looks up the plain Broadcast registered under ``bid`` and wraps its
    on-disk path in a ModelBroadcast carrying the given ``bigdl_type``.
    Used as the reconstruction callable returned by
    ModelBroadcast.__reduce__.
    """
    plain_broadcast = _from_id(bid)
    return ModelBroadcast(path=plain_broadcast._path,
                          bigdl_type=bigdl_type)
|
||
def broadcastModel(sc, layer):
    """Broadcast a BigDL model (``layer``) to the executors.

    Mirrors ``SparkContext.broadcast`` but uses ModelBroadcast so the
    model travels via BigDL's native serialization instead of pickle.
    """
    pickle_registry = sc._pickled_broadcast_vars
    return ModelBroadcast(sc, layer, pickle_registry)
|
||
class ModelBroadcast(Broadcast):
    """A Broadcast variant for BigDL models.

    Serializes the model with BigDL's native ``saveModel``/``loadModel``
    instead of pickle, so only the file path travels through the normal
    Broadcast machinery.
    """

    def __init__(self, sc=None, layer=None, pickle_registry=None, path=None, bigdl_type="float"):
        """
        Should not be called directly by users -- use L{SparkContext.broadcast()}
        instead.
        """
        # On the driver the element type is taken from the model itself;
        # on the worker (path-only reconstruction) it is passed explicitly.
        if layer is not None:
            self.bigdl_type = layer.bigdl_type
        else:
            self.bigdl_type = bigdl_type
        super(ModelBroadcast, self).__init__(sc, layer, pickle_registry, path)

    def dump(self, value, f):
        """Serialize ``value`` (a BigDL model) to the open file ``f``.

        Returns the file's path. The file handle is always closed, even
        when serialization fails (the original code leaked it on the
        error path because ``raise`` skipped the trailing ``f.close()``).

        :raises ValueError: if the model cannot be serialized.
        """
        try:
            value.saveModel(f.name, over_write=True)
        except Exception as e:
            msg = "Could not serialize broadcast: %s" % e.__class__.__name__
            print_exec(sys.stderr)
            raise ValueError(msg)
        finally:
            # Close unconditionally so the temp-file handle is not leaked
            # when saveModel raises.
            f.close()
        return f.name

    def _load(self, path):
        """Deserialize the model from ``path`` using BigDL's loader."""
        return Model.loadModel(path, bigdl_type=self.bigdl_type)

    @property
    def value(self):
        """ Return the broadcasted value
        """
        # Lazily load the model from disk the first time it is accessed
        # on a worker; afterwards the cached _value is returned.
        if not hasattr(self, "_value") and self._path is not None:
            self._value = self._load(self._path)
        return self._value

    def __reduce__(self):
        # Only the driver holds a JVM-side broadcast handle; workers must
        # never re-serialize a broadcast variable.
        if self._jbroadcast is None:
            raise Exception("Broadcast can only be serialized in driver")
        self._pickle_registry.add(self)
        return _from_id_and_type, (self._jbroadcast.id(), self.bigdl_type)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,16 +16,17 @@ | |
|
||
import os | ||
import sys | ||
import glob | ||
from py4j.protocol import Py4JJavaError | ||
from py4j.java_gateway import JavaObject | ||
from py4j.java_collections import ListConverter, JavaArray, JavaList, JavaMap, MapConverter | ||
from py4j.java_gateway import JavaGateway, GatewayClient | ||
|
||
from pyspark import RDD, SparkContext | ||
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer | ||
from pyspark.sql import DataFrame, SQLContext | ||
from pyspark.mllib.common import callJavaFunc | ||
from pyspark import SparkConf | ||
from pyspark.files import SparkFiles | ||
import numpy as np | ||
import threading | ||
import tempfile | ||
|
@@ -46,13 +47,18 @@ class SingletonMixin(object): | |
|
||
@classmethod | ||
def instance(cls, | ||
bigdl_type="float"): | ||
bigdl_type, *args): | ||
if not cls._instance: | ||
with cls._lock: | ||
if not cls._instance: | ||
cls._instance = cls(bigdl_type) | ||
cls._instance = cls(bigdl_type, *args) | ||
return cls._instance | ||
|
||
class GatewayWrapper(SingletonMixin):
    """Singleton holder for a py4j JavaGateway used on executors.

    Connects to the JVM-side gateway server listening on ``port``
    (py4j's default 25333 unless overridden).
    """

    def __init__(self, bigdl_type, port=25333):
        # bigdl_type is unused here; it is only present because the shared
        # SingletonMixin.instance() signature passes it to every subclass.
        client = GatewayClient(port=port)
        self.value = JavaGateway(client, auto_convert=True)
|
||
|
||
class JavaCreator(SingletonMixin): | ||
__creator_class=["com.intel.analytics.bigdl.python.api.PythonBigDLKeras"] | ||
|
@@ -74,11 +80,10 @@ def set_creator_class(cls, cclass): | |
JavaCreator.__creator_class = cclass | ||
JavaCreator._instance = None | ||
|
||
def __init__(self, bigdl_type): | ||
sc = get_spark_context() | ||
def __init__(self, bigdl_type, gateway): | ||
self.value = [] | ||
for creator_class in JavaCreator.get_creator_class(): | ||
jclass = getattr(sc._jvm, creator_class) | ||
jclass = getattr(gateway.jvm, creator_class) | ||
if bigdl_type == "float": | ||
self.value.append(getattr(jclass, "ofFloat")()) | ||
elif bigdl_type == "double": | ||
|
@@ -437,6 +442,9 @@ def uniform(self, a, b, size): | |
def init_engine(bigdl_type="float"):
    """Initialize the BigDL engine on the JVM side.

    Must be called before any other BigDL API that touches the JVM.
    """
    callBigDlFunc(bigdl_type, "initEngine")
|
||
def init_executor_gateway(sc, bigdl_type="float"):
    """Start the executor-side py4j gateway servers.

    Tells the JVM to launch a gateway on each executor so Python worker
    processes can call back into Java; the driver gateway's client port
    is passed along so executors know where to publish their own port.
    """
    driver_client_port = sc._gateway._gateway_client.port
    callBigDlFunc(bigdl_type, "initExecutorGateway", sc, driver_client_port)
|
||
|
||
def redire_spark_logs(bigdl_type="float", log_path=os.getcwd()+"/bigdl.log"): | ||
""" | ||
|
@@ -556,16 +564,33 @@ def get_spark_sql_context(sc): | |
else: | ||
return SQLContext(sc) # Compatible with Spark1.5.1 | ||
|
||
def _get_port():
    """Read the executor-side py4j gateway port from the ``gateway_port``
    file distributed to every worker via SparkFiles.

    The original version leaked the file handle and surfaced a bare
    IOError/ValueError when the file was missing or malformed; this
    closes the file deterministically and reports a meaningful error
    (as promised in the review thread).

    :return: the gateway port as an int.
    :raises RuntimeError: if the port file is missing or unreadable.
    """
    root_dir = SparkFiles.getRootDirectory()
    path = os.path.join(root_dir, "gateway_port")
    try:
        with open(path) as f:
            port = int(f.readline())
    except IOError:
        raise RuntimeError(
            "Could not open the gateway port file %s; was the executor "
            "gateway initialized via init_executor_gateway()?" % path)
    except ValueError:
        raise RuntimeError(
            "Gateway port file %s does not contain a valid integer port" % path)
    return port
|
||
def _get_gateway():
    """Return the py4j gateway appropriate for the current process.

    On a worker this is the singleton GatewayWrapper connected through
    the port distributed via SparkFiles; on the driver it is the
    SparkContext's own gateway.
    """
    if SparkFiles._is_running_on_worker:
        worker_port = _get_port()
        return GatewayWrapper.instance(None, worker_port).value
    return get_spark_context()._gateway
|
||
|
||
def callBigDlFunc(bigdl_type, name, *args): | ||
""" Call API in PythonBigDL """ | ||
sc = get_spark_context() | ||
gateway = _get_gateway() | ||
error = Exception("Cannot find function: %s" % name) | ||
for jinvoker in JavaCreator.instance(bigdl_type=bigdl_type).value: | ||
for jinvoker in JavaCreator.instance(bigdl_type, gateway).value: | ||
# hasattr(jinvoker, name) always return true here, | ||
# so you need to invoke the method to check if it exist or not | ||
try: | ||
api = getattr(jinvoker, name) | ||
result = callJavaFunc(sc, api, *args) | ||
result = callJavaFunc(api, *args) | ||
except Exception as e: | ||
error = e | ||
if "does not exist" not in str(e): | ||
|
@@ -575,7 +600,7 @@ def callBigDlFunc(bigdl_type, name, *args): | |
raise error | ||
|
||
|
||
def _java2py(sc, r, encoding="bytes"): | ||
def _java2py(gateway, r, encoding="bytes"): | ||
if isinstance(r, JavaObject): | ||
clsName = r.getClass().getSimpleName() | ||
# convert RDD into JavaRDD | ||
|
@@ -584,20 +609,20 @@ def _java2py(sc, r, encoding="bytes"): | |
clsName = 'JavaRDD' | ||
|
||
if clsName == 'JavaRDD': | ||
jrdd = sc._jvm.SerDe.javaToPython(r) | ||
return RDD(jrdd, sc) | ||
jrdd = gateway.jvm.SerDe.javaToPython(r) | ||
return RDD(jrdd, get_spark_context()) | ||
|
||
if clsName == 'DataFrame': | ||
return DataFrame(r, get_spark_sql_context(sc)) | ||
return DataFrame(r, get_spark_sql_context(get_spark_context())) | ||
|
||
if clsName == 'Dataset': | ||
return DataFrame(r, get_spark_sql_context(sc)) | ||
return DataFrame(r, get_spark_sql_context(get_spark_context())) | ||
|
||
if clsName in _picklable_classes: | ||
r = sc._jvm.org.apache.spark.bigdl.api.python.BigDLSerDe.dumps(r) | ||
r = gateway.jvm.org.apache.spark.bigdl.api.python.BigDLSerDe.dumps(r) | ||
elif isinstance(r, (JavaArray, JavaList, JavaMap)): | ||
try: | ||
r = sc._jvm.org.apache.spark.bigdl.api.python.BigDLSerDe.dumps( | ||
r = gateway.jvm.org.apache.spark.bigdl.api.python.BigDLSerDe.dumps( | ||
r) | ||
except Py4JJavaError: | ||
pass # not pickable | ||
|
@@ -607,11 +632,12 @@ def _java2py(sc, r, encoding="bytes"): | |
return r | ||
|
||
|
||
def callJavaFunc(func, *args):
    """ Call Java Function """
    # Resolve the gateway locally so this works both on the driver and
    # inside executor Python workers.
    gateway = _get_gateway()
    java_args = [_py2java(gateway, a) for a in args]
    return _java2py(gateway, func(*java_args))
|
||
|
||
def _to_java_object_rdd(rdd): | ||
|
@@ -627,7 +653,7 @@ def _to_java_object_rdd(rdd): | |
rdd._jrdd, True) | ||
|
||
|
||
def _py2java(sc, obj): | ||
def _py2java(gateway, obj): | ||
""" Convert Python object into Java """ | ||
if isinstance(obj, RDD): | ||
obj = _to_java_object_rdd(obj) | ||
|
@@ -636,13 +662,13 @@ def _py2java(sc, obj): | |
elif isinstance(obj, SparkContext): | ||
obj = obj._jsc | ||
elif isinstance(obj, (list, tuple)): | ||
obj = ListConverter().convert([_py2java(sc, x) for x in obj], | ||
sc._gateway._gateway_client) | ||
obj = ListConverter().convert([_py2java(gateway, x) for x in obj], | ||
gateway._gateway_client) | ||
elif isinstance(obj, dict): | ||
result = {} | ||
for (key, value) in obj.items(): | ||
result[key] = _py2java(sc, value) | ||
obj = MapConverter().convert(result, sc._gateway._gateway_client) | ||
result[key] = _py2java(gateway, value) | ||
obj = MapConverter().convert(result, gateway._gateway_client) | ||
elif isinstance(obj, JavaValue): | ||
obj = obj.value | ||
elif isinstance(obj, JavaObject): | ||
|
@@ -651,7 +677,7 @@ def _py2java(sc, obj): | |
pass | ||
else: | ||
data = bytearray(PickleSerializer().dumps(obj)) | ||
obj = sc._jvm.org.apache.spark.bigdl.api.python.BigDLSerDe.loads(data) | ||
obj = gateway.jvm.org.apache.spark.bigdl.api.python.BigDLSerDe.loads(data) | ||
return obj | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ | |
from numpy.testing import assert_allclose, assert_array_equal | ||
from bigdl.util.engine import compare_version | ||
from bigdl.transform.vision.image import * | ||
from bigdl.models.utils.model_broadcast import broadcastModel | ||
np.random.seed(1337) # for reproducibility | ||
|
||
|
||
|
@@ -533,7 +534,7 @@ def test_save_jtensor_dict(self): | |
tensors["tensor1"] = JTensor.from_ndarray(np.random.rand(3, 2)) | ||
tensors["tensor2"] = JTensor.from_ndarray(np.random.rand(3, 2)) | ||
# in old impl, this will throw an exception | ||
_py2java(self.sc, tensors) | ||
_py2java(self.sc._gateway, tensors) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add unittest for the new added ModelBroadcast? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added |
||
|
||
def test_compare_version(self): | ||
assert compare_version("2.1.1", "2.2.0") == -1 | ||
|
@@ -601,5 +602,17 @@ def test_local_predict_multiple_input(self): | |
JTensor.from_ndarray(np.ones([4, 3]))]) | ||
assert result4.shape == (4,) | ||
|
||
def test_model_broadcast(self): | ||
|
||
init_executor_gateway(self.sc) | ||
model = Linear(3, 2) | ||
broadcasted = broadcastModel(self.sc, model) | ||
input_data = np.random.rand(3) | ||
output = self.sc.parallelize([input_data], 1)\ | ||
.map(lambda x: broadcasted.value.forward(x)).first() | ||
expected = model.forward(input_data) | ||
|
||
assert_allclose(output, expected) | ||
|
||
if __name__ == "__main__": | ||
pytest.main([__file__]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When is this method called?