# test.py (forked from yahoo/TensorFlowOnSpark)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
import unittest
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
class SparkTest(unittest.TestCase):
    """Base class for unittests using Spark. Sets up and tears down a cluster per test class.

    Required environment variables:
      MASTER -- master URL handed to ``SparkContext``.
      SPARK_WORKER_INSTANCES -- integer, stored as ``cls.num_workers``.
      SPARK_CLASSPATH -- must be non-empty and mention a ``tensorflow-hadoop``
        jar; it is passed through as the ``spark.jars`` setting.

    Class attributes set by ``setUpClass``: ``num_workers``, ``conf`` (the
    ``SparkConf``), ``sc`` (the ``SparkContext``) and ``spark`` (the
    ``SparkSession``).
    """

    @classmethod
    def setUpClass(cls):
        """Validate the environment and start a SparkContext/SparkSession for the class.

        Raises:
            cls.failureException (AssertionError by default): a required
                environment variable is missing or invalid.
        """
        # Explicit raises rather than `assert`: asserts are stripped under
        # `python -O`, which would turn these checks into obscure later errors.
        master = os.getenv('MASTER')
        if master is None:
            raise cls.failureException("Please start a Spark standalone cluster and export MASTER to your env.")

        num_workers = os.getenv('SPARK_WORKER_INSTANCES')
        if num_workers is None:
            raise cls.failureException("Please export SPARK_WORKER_INSTANCES to your env.")
        cls.num_workers = int(num_workers)

        spark_jars = os.getenv('SPARK_CLASSPATH')
        if not spark_jars or 'tensorflow-hadoop' not in spark_jars:
            raise cls.failureException("Please add path to tensorflow-hadoop-*.jar to SPARK_CLASSPATH.")

        cls.conf = SparkConf().set('spark.jars', spark_jars)
        cls.sc = SparkContext(master, cls.__name__, conf=cls.conf)
        try:
            # getOrCreate() picks up the SparkContext that is already active.
            cls.spark = SparkSession.builder.getOrCreate()
        except Exception:
            # unittest skips tearDownClass when setUpClass raises, so release
            # the context here instead of leaking it.
            cls.sc.stop()
            raise

    @classmethod
    def tearDownClass(cls):
        """Stop the SparkSession and the SparkContext created in setUpClass."""
        try:
            cls.spark.stop()
        finally:
            # Ensure the context is released even if stopping the session failed.
            cls.sc.stop()
class SimpleTest(SparkTest):
    """Check that basic Spark is working"""

    def test_spark(self):
        """Sum 0..n-1 on the cluster and compare with the closed-form n*(n-1)/2."""
        n = 1000
        # Named `total` rather than `sum` so the builtin is not shadowed.
        total = self.sc.parallelize(range(n)).sum()
        self.assertEqual(total, n * (n - 1) // 2)  # 499500 for n = 1000
if __name__ == '__main__':
    # Script entry point: discover and run the TestCases defined in this module.
    unittest.main(module='__main__')