Skip to content

Commit

Permalink
add integration test
Browse files Browse the repository at this point in the history
  • Loading branch information
yangw1234 committed Feb 9, 2018
1 parent e1c4ffe commit b9c3e5a
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions pyspark/test/bigdl/test_simple_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from numpy.testing import assert_allclose, assert_array_equal
from bigdl.util.engine import compare_version
from bigdl.transform.vision.image import *
from bigdl.models.utils.model_broadcast import broadcastModel
np.random.seed(1337) # for reproducibility


Expand Down Expand Up @@ -601,5 +602,17 @@ def test_local_predict_multiple_input(self):
JTensor.from_ndarray(np.ones([4, 3]))])
assert result4.shape == (4,)

def test_model_broadcast(self):

init_executor_gateway(self.sc)
model = Linear(3, 2)
broadcasted = broadcastModel(self.sc, model)
input_data = np.random.rand(3)
output = self.sc.parallelize([input_data], 1)\
.map(lambda x: broadcasted.value.forward(x)).first()
expected = model.forward(input_data)

assert_allclose(output, expected)

if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit b9c3e5a

Please sign in to comment.