|
14 | 14 | # See the License for the specific language governing permissions and |
15 | 15 | # limitations under the License. |
16 | 16 | # |
| 17 | +import os |
| 18 | +import sys |
17 | 19 |
|
18 | 20 | from py4j.java_collections import ListConverter |
19 | 21 | from py4j.java_gateway import java_import |
20 | 22 |
|
21 | | -from pyspark import RDD |
| 23 | +from pyspark import RDD, SparkConf |
22 | 24 | from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer |
23 | 25 | from pyspark.context import SparkContext |
24 | 26 | from pyspark.storagelevel import StorageLevel |
@@ -75,41 +77,81 @@ class StreamingContext(object): |
75 | 77 | respectively. `context.awaitTermination()` allows the current thread |
76 | 78 | to wait for the termination of the context by `stop()` or by an exception. |
77 | 79 | """ |
| 80 | + _transformerSerializer = None |
78 | 81 |
|
79 | | - def __init__(self, sparkContext, duration): |
| 82 | + def __init__(self, sparkContext, duration=None, jssc=None): |
80 | 83 | """ |
81 | 84 | Create a new StreamingContext. |
82 | 85 |
|
83 | 86 | @param sparkContext: L{SparkContext} object. |
84 | 87 | @param duration: number of seconds. |
85 | 88 | """ |
| 89 | + |
86 | 90 | self._sc = sparkContext |
87 | 91 | self._jvm = self._sc._jvm |
88 | | - self._start_callback_server() |
89 | | - self._jssc = self._initialize_context(self._sc, duration) |
| 92 | + self._jssc = jssc or self._initialize_context(self._sc, duration) |
| 93 | + |
| 94 | + def _initialize_context(self, sc, duration): |
| 95 | + self._ensure_initialized() |
| 96 | + return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration)) |
| 97 | + |
| 98 | + def _jduration(self, seconds): |
| 99 | + """ |
| 100 | + Create Duration object given number of seconds |
| 101 | + """ |
| 102 | + return self._jvm.Duration(int(seconds * 1000)) |
90 | 103 |
|
91 | | - def _start_callback_server(self): |
92 | | - gw = self._sc._gateway |
| 104 | + @classmethod |
| 105 | + def _ensure_initialized(cls): |
| 106 | + SparkContext._ensure_initialized() |
| 107 | + gw = SparkContext._gateway |
| 108 | + # start callback server |
93 | 109 | # getattr would fall back to the JVM, so check __dict__ directly |
94 | 110 | if "_callback_server" not in gw.__dict__: |
95 | 111 | _daemonize_callback_server() |
96 | 112 | gw._start_callback_server(gw._python_proxy_port) |
97 | | - gw._python_proxy_port = gw._callback_server.port # update port with real port |
98 | 113 |
|
99 | | - def _initialize_context(self, sc, duration): |
100 | | - java_import(self._jvm, "org.apache.spark.streaming.*") |
101 | | - java_import(self._jvm, "org.apache.spark.streaming.api.java.*") |
102 | | - java_import(self._jvm, "org.apache.spark.streaming.api.python.*") |
| 114 | + java_import(gw.jvm, "org.apache.spark.streaming.*") |
| 115 | + java_import(gw.jvm, "org.apache.spark.streaming.api.java.*") |
| 116 | + java_import(gw.jvm, "org.apache.spark.streaming.api.python.*") |
103 | 117 | # register serializer for RDDFunction |
104 | | - ser = RDDFunctionSerializer(self._sc, CloudPickleSerializer()) |
105 | | - self._jvm.PythonDStream.registerSerializer(ser) |
106 | | - return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration)) |
| 118 | + # this can happen before a SparkContext is created, when loading from a checkpoint |
| 119 | + cls._transformerSerializer = RDDFunctionSerializer(SparkContext._active_spark_context, |
| 120 | + CloudPickleSerializer(), gw) |
| 121 | + gw.jvm.PythonDStream.registerSerializer(cls._transformerSerializer) |
107 | 122 |
|
108 | | - def _jduration(self, seconds): |
| 123 | + @classmethod |
| 124 | + def getOrCreate(cls, path, setupFunc): |
109 | 125 | """ |
110 | | - Create Duration object given number of seconds |
| 126 | + Get the StreamingContext from the checkpoint directory at `path`, or set |
| 127 | + it up by calling `setupFunc`. |
| 128 | +
|
| 129 | + :param path: directory of checkpoint |
| 130 | + :param setupFunc: a function used to create a StreamingContext and |
| 131 | + set up DStreams. |
| 132 | + :return: a StreamingContext |
111 | 133 | """ |
112 | | - return self._jvm.Duration(int(seconds * 1000)) |
| 134 | + if not os.path.exists(path) or not os.path.isdir(path) or not os.listdir(path): |
| 135 | + ssc = setupFunc() |
| 136 | + ssc.checkpoint(path) |
| 137 | + return ssc |
| 138 | + |
| 139 | + cls._ensure_initialized() |
| 140 | + gw = SparkContext._gateway |
| 141 | + |
| 142 | + try: |
| 143 | + jssc = gw.jvm.JavaStreamingContext(path) |
| 144 | + except Exception: |
| 145 | + print >>sys.stderr, "failed to load StreamingContext from checkpoint" |
| 146 | + raise |
| 147 | + |
| 148 | + jsc = jssc.sparkContext() |
| 149 | + conf = SparkConf(_jconf=jsc.getConf()) |
| 150 | + sc = SparkContext(conf=conf, gateway=gw, jsc=jsc) |
| 151 | + # update ctx in serializer |
| 152 | + SparkContext._active_spark_context = sc |
| 153 | + cls._transformerSerializer.ctx = sc |
| 154 | + return StreamingContext(sc, None, jssc) |
113 | 155 |
|
114 | 156 | @property |
115 | 157 | def sparkContext(self): |
|
0 commit comments