
Commit 6433570

Merge pull request apache#77 from mesosphere/add-pyspark-tests
Add pyspark test
2 parents e67eaf8 + 75086bc

3 files changed: +100 −12 lines

tests/jobs/PySparkTestInclude.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
def func():
    print("Import is working")

tests/jobs/pi_with_include.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
from __future__ import print_function
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys
from random import random
from operator import add

from pyspark.sql import SparkSession

import PySparkTestInclude

if __name__ == "__main__":
    """
    Usage: pi [partitions]
    """

    # Make sure we can include this user-provided module
    PySparkTestInclude.func()

    spark = SparkSession\
        .builder\
        .appName("PythonPi")\
        .getOrCreate()

    partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 2
    n = 100000 * partitions

    def f(_):
        x = random() * 2 - 1
        y = random() * 2 - 1
        return 1 if x ** 2 + y ** 2 < 1 else 0

    count = spark.sparkContext.parallelize(range(1, n + 1), partitions).map(f).reduce(add)
    print("Pi is roughly %f" % (4.0 * count / n))

    spark.stop()
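
The job is the stock PythonPi example plus the PySparkTestInclude.func() call. The estimator itself can be checked without a cluster: points drawn uniformly in a square land inside the inscribed circle with probability pi/4 (the same fraction whether the square is [-1, 1] x [-1, 1] as above or [0, 1) x [0, 1)). A pure-Python sketch of the same Monte Carlo step, illustrative only and not part of the commit:

from random import random

def estimate_pi(n):
    # Fraction of uniform points in [0, 1) x [0, 1) that fall inside the
    # unit circle approximates pi / 4.
    count = sum(1 for _ in range(n) if random() ** 2 + random() ** 2 < 1)
    return 4.0 * count / n

print("Pi is roughly %f" % estimate_pi(100000))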

tests/test.py

Lines changed: 47 additions & 12 deletions
@@ -13,14 +13,21 @@
 import shakedown
 
 
-def upload_jar(jar):
+def upload_file(file_path):
     conn = S3Connection(os.environ['AWS_ACCESS_KEY_ID'], os.environ['AWS_SECRET_ACCESS_KEY'])
     bucket = conn.get_bucket(os.environ['S3_BUCKET'])
-    basename = os.path.basename(jar)
+    basename = os.path.basename(file_path)
+
+    if basename.endswith('.jar'):
+        content_type = 'application/java-archive'
+    elif basename.endswith('.py'):
+        content_type = 'application/x-python'
+    else:
+        raise ValueError("Unexpected file type: {}. Expected .jar or .py file.".format(basename))
 
     key = Key(bucket, '{}/{}'.format(os.environ['S3_PREFIX'], basename))
-    key.metadata = {'Content-Type': 'application/java-archive'}
-    key.set_contents_from_filename(jar)
+    key.metadata = {'Content-Type': content_type}
+    key.set_contents_from_filename(file_path)
     key.make_public()
 
     jar_url = "http://{0}.s3.amazonaws.com/{1}/{2}".format(
@@ -31,10 +38,18 @@ def upload_jar(jar):
     return jar_url
 
 
-def submit_job(jar_url):
-    spark_job_runner_args = 'http://leader.mesos:5050 dcos \\"*\\" spark:only 2'
-    submit_args = "-Dspark.driver.memory=2g --class com.typesafe.spark.test.mesos.framework.runners.SparkJobRunner {0} {1}".format(
-        jar_url, spark_job_runner_args)
+def submit_job(app_resource_url, app_args, app_class, py_files):
+    if app_class is not None:
+        app_class_option = '--class {} '.format(app_class)
+    else:
+        app_class_option = ''
+    if py_files is not None:
+        py_files_option = '--py-files {} '.format(py_files)
+    else:
+        py_files_option = ''
+
+    submit_args = "-Dspark.driver.memory=2g {0}{1}{2} {3}".format(
+        app_class_option, py_files_option, app_resource_url, app_args)
     cmd = 'dcos --log-level=DEBUG spark --verbose run --submit-args="{0}"'.format(submit_args)
     print('Running {}'.format(cmd))
     stdout = subprocess.check_output(cmd, shell=True).decode('utf-8')
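
With the pyspark test's arguments, submit_job assembles a command along these lines; the URLs below are placeholders standing in for the real uploaded S3 URLs, not captured output:

# Placeholders for the uploaded S3 URLs:
app_resource_url = 'http://BUCKET.s3.amazonaws.com/PREFIX/pi_with_include.py'
py_files = 'http://BUCKET.s3.amazonaws.com/PREFIX/PySparkTestInclude.py'

submit_args = "-Dspark.driver.memory=2g {0}{1}{2} {3}".format(
    '',                                  # app_class_option: empty for a Python app
    '--py-files {} '.format(py_files),   # py_files_option
    app_resource_url,
    '30')                                # app_args: number of partitions
print('dcos --log-level=DEBUG spark --verbose run --submit-args="{0}"'.format(submit_args))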
@@ -52,14 +67,34 @@ def task_log(task_id):
     return stdout
 
 
-def main():
-    jar_url = upload_jar(os.getenv('TEST_JAR_PATH'))
-    task_id = submit_job(jar_url)
+def run_tests(app_path, app_args, expected_output, app_class=None, py_file_path=None):
+    app_resource_url = upload_file(app_path)
+    if py_file_path is not None:
+        py_file_url = upload_file(py_file_path)
+    else:
+        py_file_url = None
+    task_id = submit_job(app_resource_url, app_args, app_class, py_file_url)
     print('Waiting for task id={} to complete'.format(task_id))
     shakedown.wait_for_task_completion(task_id)
     log = task_log(task_id)
     print(log)
-    assert "All tests passed" in log
+    assert expected_output in log
+
+
+def main():
+    spark_job_runner_args = 'http://leader.mesos:5050 dcos \\"*\\" spark:only 2'
+    run_tests(os.getenv('TEST_JAR_PATH'),
+              spark_job_runner_args,
+              "All tests passed",
+              app_class='com.typesafe.spark.test.mesos.framework.runners.SparkJobRunner')
+
+    script_dir = os.path.dirname(os.path.abspath(__file__))
+    python_script_path = os.path.join(script_dir, 'jobs', 'pi_with_include.py')
+    py_file_path = os.path.join(script_dir, 'jobs', 'PySparkTestInclude.py')
+    run_tests(python_script_path,
+              '30',
+              "Pi is roughly 3",
+              py_file_path=py_file_path)
 
 
 if __name__ == '__main__':
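
The script depends on several environment variables, all referenced above: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, S3_BUCKET, S3_PREFIX, and TEST_JAR_PATH. A guard like the following could fail fast before main() runs; a sketch only, not part of the change:

import os

REQUIRED_ENV = ['AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY',
                'S3_BUCKET', 'S3_PREFIX', 'TEST_JAR_PATH']

missing = [name for name in REQUIRED_ENV if name not in os.environ]
if missing:
    raise EnvironmentError('missing environment variables: ' + ', '.join(missing))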
