
Commit 44c622b

Add support to run specific unittests and/or doctests in python/run-tests script

1 parent 676bbb2
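
With this change, specific tests can be selected from the command line instead of whole modules. For example (FooTests is a hypothetical class name borrowed from the new option's help text), an invocation would look something like:

python/run-tests --testnames 'pyspark.sql.tests FooTests,pyspark.sql.tests FooTests.test_foo'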

File tree: 2 files changed (+56, -32 lines)

python/run-tests-with-coverage (0 additions, 2 deletions)

@@ -50,8 +50,6 @@ export SPARK_CONF_DIR="$COVERAGE_DIR/conf"
 # This environment variable enables the coverage.
 export COVERAGE_PROCESS_START="$FWDIR/.coveragerc"
 
-# If you'd like to run a specific unittest class, you could do such as
-# SPARK_TESTING=1 ../bin/pyspark pyspark.sql.tests VectorizedUDFTests
 ./run-tests "$@"
 
 # Don't run coverage for the coverage command itself
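
The two lines removed above documented a manual way to run a single unittest class by hand (SPARK_TESTING=1 ../bin/pyspark pyspark.sql.tests VectorizedUDFTests); with the --testnames option added below, the same selection can simply be forwarded through run-tests "$@", so the hint appears to be redundant now.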

python/run-tests.py (56 additions, 30 deletions)
@@ -19,7 +19,7 @@
 
 from __future__ import print_function
 import logging
-from optparse import OptionParser
+from optparse import OptionParser, OptionGroup
 import os
 import re
 import shutil
@@ -93,17 +93,18 @@ def run_individual_python_test(target_dir, test_name, pyspark_python):
         "pyspark-shell"
     ]
     env["PYSPARK_SUBMIT_ARGS"] = " ".join(spark_args)
-
-    LOGGER.info("Starting test(%s): %s", pyspark_python, test_name)
+    str_test_name = " ".join(test_name)
+    LOGGER.info("Starting test(%s): %s", pyspark_python, str_test_name)
     start_time = time.time()
     try:
         per_test_output = tempfile.TemporaryFile()
         retcode = subprocess.Popen(
-            [os.path.join(SPARK_HOME, "bin/pyspark"), test_name],
+            (os.path.join(SPARK_HOME, "bin/pyspark"), ) + test_name,
             stderr=per_test_output, stdout=per_test_output, env=env).wait()
         shutil.rmtree(tmp_dir, ignore_errors=True)
     except:
-        LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python)
+        LOGGER.exception(
+            "Got exception while running %s with %s", str_test_name, pyspark_python)
         # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
         # this code is invoked from a thread other than the main thread.
         os._exit(1)
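
The hunk above turns each test goal from a single string into a tuple of argv parts, joined back into str_test_name only for logging. A minimal sketch of the shape (SPARK_HOME and the test names below are illustrative, not from the diff):

import os

SPARK_HOME = "/path/to/spark"                           # hypothetical location
test_name = ("pyspark.sql.tests", "FooTests.test_foo")  # hypothetical selection

str_test_name = " ".join(test_name)  # human-readable form for log messages
cmd = (os.path.join(SPARK_HOME, "bin/pyspark"), ) + test_name
# subprocess.Popen(cmd, ...) receives each tuple element as a separate argv
# entry, so bin/pyspark sees the module and the test selector as two arguments.
print(cmd)  # ('/path/to/spark/bin/pyspark', 'pyspark.sql.tests', 'FooTests.test_foo')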
@@ -124,7 +125,8 @@ def run_individual_python_test(target_dir, test_name, pyspark_python):
         except:
             LOGGER.exception("Got an exception while trying to print failed test output")
         finally:
-            print_red("\nHad test failures in %s with %s; see logs." % (test_name, pyspark_python))
+            print_red("\nHad test failures in %s with %s; see logs." % (
+                str_test_name, pyspark_python))
             # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
             # this code is invoked from a thread other than the main thread.
             os._exit(-1)
@@ -140,7 +142,7 @@ def run_individual_python_test(target_dir, test_name, pyspark_python):
                 decoded_lines))
             skipped_counts = len(skipped_tests)
             if skipped_counts > 0:
-                key = (pyspark_python, test_name)
+                key = (pyspark_python, str_test_name)
                 SKIPPED_TESTS[key] = skipped_tests
             per_test_output.close()
         except:
@@ -152,11 +154,11 @@ def run_individual_python_test(target_dir, test_name, pyspark_python):
             os._exit(-1)
         if skipped_counts != 0:
             LOGGER.info(
-                "Finished test(%s): %s (%is) ... %s tests were skipped", pyspark_python, test_name,
-                duration, skipped_counts)
+                "Finished test(%s): %s (%is) ... %s tests were skipped", pyspark_python,
+                str_test_name, duration, skipped_counts)
         else:
             LOGGER.info(
-                "Finished test(%s): %s (%is)", pyspark_python, test_name, duration)
+                "Finished test(%s): %s (%is)", pyspark_python, str_test_name, duration)
 
 
 def get_default_python_executables():
@@ -190,6 +192,20 @@ def parse_opts():
         help="Enable additional debug logging"
     )
 
+    group = OptionGroup(parser, "Developer Options")
+    group.add_option(
+        "--testnames", type="string",
+        default=None,
+        help=(
+            "A comma-separated list of specific modules, classes and functions of doctest "
+            "or unittest to test. "
+            "For example, 'pyspark.sql.foo' to run the module as unittests or doctests, "
+            "'pyspark.sql.tests FooTests' to run the specific class of unittests, "
+            "'pyspark.sql.tests FooTests.test_foo' to run the specific unittest in the class. "
+            "'--modules' option is ignored if they are given.")
+    )
+    parser.add_option_group(group)
+
     (opts, args) = parser.parse_args()
     if args:
         parser.error("Unsupported arguments: %s" % ' '.join(args))
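
For reference, a minimal sketch of the optparse pattern introduced in this hunk: options registered on an OptionGroup parse exactly like top-level options but are listed under their own heading in --help output (the sample value below is hypothetical):

from optparse import OptionParser, OptionGroup

parser = OptionParser()
group = OptionGroup(parser, "Developer Options")
group.add_option("--testnames", type="string", default=None,
                 help="A comma-separated list of specific tests to run.")
parser.add_option_group(group)

# Parses exactly like a top-level option:
(opts, args) = parser.parse_args(["--testnames", "pyspark.sql.tests FooTests"])
print(opts.testnames)  # pyspark.sql.tests FooTests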
@@ -213,25 +229,31 @@ def _check_coverage(python_exec):
 
 def main():
     opts = parse_opts()
-    if (opts.verbose):
+    if opts.verbose:
         log_level = logging.DEBUG
     else:
         log_level = logging.INFO
+    should_test_modules = opts.testnames is None
     logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s")
     LOGGER.info("Running PySpark tests. Output is in %s", LOG_FILE)
     if os.path.exists(LOG_FILE):
         os.remove(LOG_FILE)
     python_execs = opts.python_executables.split(',')
-    modules_to_test = []
-    for module_name in opts.modules.split(','):
-        if module_name in python_modules:
-            modules_to_test.append(python_modules[module_name])
-        else:
-            print("Error: unrecognized module '%s'. Supported modules: %s" %
-                  (module_name, ", ".join(python_modules)))
-            sys.exit(-1)
     LOGGER.info("Will test against the following Python executables: %s", python_execs)
-    LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test])
+
+    if should_test_modules:
+        modules_to_test = []
+        for module_name in opts.modules.split(','):
+            if module_name in python_modules:
+                modules_to_test.append(python_modules[module_name])
+            else:
+                print("Error: unrecognized module '%s'. Supported modules: %s" %
+                      (module_name, ", ".join(python_modules)))
+                sys.exit(-1)
+        LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test])
+    else:
+        testnames_to_test = opts.testnames.split(',')
+        LOGGER.info("Will test the following Python tests: %s", testnames_to_test)
 
     task_queue = Queue.PriorityQueue()
     for python_exec in python_execs:
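
In isolation, the new selection logic reads as below (the option values are illustrative): --testnames, when given, takes precedence and --modules is ignored, matching the help text above.

class Opts(object):
    # Illustrative stand-ins for the parsed command-line options.
    modules = "pyspark-sql"
    testnames = "pyspark.sql.foo,pyspark.sql.tests FooTests"

opts = Opts()
should_test_modules = opts.testnames is None  # False here: --testnames was given
if should_test_modules:
    print("modules:", opts.modules.split(','))
else:
    print("tests:", opts.testnames.split(','))
    # -> tests: ['pyspark.sql.foo', 'pyspark.sql.tests FooTests']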
@@ -246,16 +268,20 @@ def main():
         LOGGER.debug("%s python_implementation is %s", python_exec, python_implementation)
         LOGGER.debug("%s version is: %s", python_exec, subprocess_check_output(
             [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip())
-        for module in modules_to_test:
-            if python_implementation not in module.blacklisted_python_implementations:
-                for test_goal in module.python_test_goals:
-                    heavy_tests = ['pyspark.streaming.tests', 'pyspark.mllib.tests',
-                                   'pyspark.tests', 'pyspark.sql.tests', 'pyspark.ml.tests']
-                    if any(map(lambda prefix: test_goal.startswith(prefix), heavy_tests)):
-                        priority = 0
-                    else:
-                        priority = 100
-                    task_queue.put((priority, (python_exec, test_goal)))
+        if should_test_modules:
+            for module in modules_to_test:
+                if python_implementation not in module.blacklisted_python_implementations:
+                    for test_goal in module.python_test_goals:
+                        heavy_tests = ['pyspark.streaming.tests', 'pyspark.mllib.tests',
+                                       'pyspark.tests', 'pyspark.sql.tests', 'pyspark.ml.tests']
+                        if any(map(lambda prefix: test_goal.startswith(prefix), heavy_tests)):
+                            priority = 0
+                        else:
+                            priority = 100
+                        task_queue.put((priority, (python_exec, (test_goal, ))))
+        else:
+            for test_goal in testnames_to_test:
+                task_queue.put((0, (python_exec, tuple(test_goal.split()))))
 
     # Create the target directory before starting tasks to avoid races.
     target_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'target'))
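
A minimal sketch of how a --testnames value flows into the queue (names are illustrative; Queue is the Python 2 module name the script uses): each comma-separated entry is split on whitespace into the extra argv for bin/pyspark, and priority 0 schedules explicitly requested tests ahead of the priority-100 module goals.

try:
    import Queue  # Python 2 name, as in run-tests.py
except ImportError:
    import queue as Queue  # Python 3 equivalent

task_queue = Queue.PriorityQueue()
for test_goal in "pyspark.sql.foo,pyspark.sql.tests FooTests.test_foo".split(','):
    task_queue.put((0, ("python2.7", tuple(test_goal.split()))))

while not task_queue.empty():
    priority, (python_exec, goal) = task_queue.get()
    print(priority, python_exec, goal)
# 0 python2.7 ('pyspark.sql.foo',)
# 0 python2.7 ('pyspark.sql.tests', 'FooTests.test_foo')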
