Skip to content

Commit 8efc6e9

Browse files
author
Marcelo Vanzin
committed
[SPARK-20922][CORE] Add whitelist of classes that can be deserialized by the launcher.
Blindly deserializing classes using Java serialization opens the code up to issues in other libraries, since just deserializing data from a stream may end up execution code (think readObject()). Since the launcher protocol is pretty self-contained, there's just a handful of classes it legitimately needs to deserialize, and they're in just two packages, so add a filter that throws errors if classes from any other package show up in the stream. This also maintains backwards compatibility (the updated launcher code can still communicate with the backend code in older Spark releases). Tested with new and existing unit tests. Author: Marcelo Vanzin <vanzin@cloudera.com> Closes #18166 from vanzin/SPARK-20922.
1 parent 640afa4 commit 8efc6e9

File tree

3 files changed

+121
-27
lines changed

3 files changed

+121
-27
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.launcher;
19+
20+
import java.io.InputStream;
21+
import java.io.IOException;
22+
import java.io.ObjectInputStream;
23+
import java.io.ObjectStreamClass;
24+
import java.util.Arrays;
25+
import java.util.List;
26+
27+
/**
28+
* An object input stream that only allows classes used by the launcher protocol to be in the
29+
* serialized stream. See SPARK-20922.
30+
*/
31+
class FilteredObjectInputStream extends ObjectInputStream {
32+
33+
private static final List<String> ALLOWED_PACKAGES = Arrays.asList(
34+
"org.apache.spark.launcher.",
35+
"java.lang.");
36+
37+
FilteredObjectInputStream(InputStream is) throws IOException {
38+
super(is);
39+
}
40+
41+
@Override
42+
protected Class<?> resolveClass(ObjectStreamClass desc)
43+
throws IOException, ClassNotFoundException {
44+
45+
boolean isValid = ALLOWED_PACKAGES.stream().anyMatch(p -> desc.getName().startsWith(p));
46+
if (!isValid) {
47+
throw new IllegalArgumentException(
48+
String.format("Unexpected class in stream: %s", desc.getName()));
49+
}
50+
return super.resolveClass(desc);
51+
}
52+
53+
}

launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import java.io.Closeable;
2121
import java.io.EOFException;
2222
import java.io.IOException;
23-
import java.io.ObjectInputStream;
2423
import java.io.ObjectOutputStream;
2524
import java.net.Socket;
2625
import java.util.logging.Level;
@@ -53,7 +52,7 @@ abstract class LauncherConnection implements Closeable, Runnable {
5352
@Override
5453
public void run() {
5554
try {
56-
ObjectInputStream in = new ObjectInputStream(socket.getInputStream());
55+
FilteredObjectInputStream in = new FilteredObjectInputStream(socket.getInputStream());
5756
while (!closed) {
5857
Message msg = (Message) in.readObject();
5958
handle(msg);

launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java

Lines changed: 67 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,11 @@
1919

2020
import java.io.Closeable;
2121
import java.io.IOException;
22+
import java.io.ObjectInputStream;
2223
import java.net.InetAddress;
2324
import java.net.Socket;
25+
import java.util.Arrays;
26+
import java.util.List;
2427
import java.util.concurrent.BlockingQueue;
2528
import java.util.concurrent.LinkedBlockingQueue;
2629
import java.util.concurrent.Semaphore;
@@ -120,31 +123,7 @@ public void testTimeout() throws Exception {
120123
Socket s = new Socket(InetAddress.getLoopbackAddress(),
121124
LauncherServer.getServerInstance().getPort());
122125
client = new TestClient(s);
123-
124-
// Try a few times since the client-side socket may not reflect the server-side close
125-
// immediately.
126-
boolean helloSent = false;
127-
int maxTries = 10;
128-
for (int i = 0; i < maxTries; i++) {
129-
try {
130-
if (!helloSent) {
131-
client.send(new Hello(handle.getSecret(), "1.4.0"));
132-
helloSent = true;
133-
} else {
134-
client.send(new SetAppId("appId"));
135-
}
136-
fail("Expected exception caused by connection timeout.");
137-
} catch (IllegalStateException | IOException e) {
138-
// Expected.
139-
break;
140-
} catch (AssertionError e) {
141-
if (i < maxTries - 1) {
142-
Thread.sleep(100);
143-
} else {
144-
throw new AssertionError("Test failed after " + maxTries + " attempts.", e);
145-
}
146-
}
147-
}
126+
waitForError(client, handle.getSecret());
148127
} finally {
149128
SparkLauncher.launcherConfig.remove(SparkLauncher.CHILD_CONNECTION_TIMEOUT);
150129
kill(handle);
@@ -183,6 +162,25 @@ public void infoChanged(SparkAppHandle handle) {
183162
}
184163
}
185164

165+
@Test
166+
public void testStreamFiltering() throws Exception {
167+
ChildProcAppHandle handle = LauncherServer.newAppHandle();
168+
TestClient client = null;
169+
try {
170+
Socket s = new Socket(InetAddress.getLoopbackAddress(),
171+
LauncherServer.getServerInstance().getPort());
172+
173+
client = new TestClient(s);
174+
client.send(new EvilPayload());
175+
waitForError(client, handle.getSecret());
176+
assertEquals(0, EvilPayload.EVIL_BIT);
177+
} finally {
178+
kill(handle);
179+
close(client);
180+
client.clientThread.join();
181+
}
182+
}
183+
186184
private void kill(SparkAppHandle handle) {
187185
if (handle != null) {
188186
handle.kill();
@@ -199,6 +197,35 @@ private void close(Closeable c) {
199197
}
200198
}
201199

200+
/**
201+
* Try a few times to get a client-side error, since the client-side socket may not reflect the
202+
* server-side close immediately.
203+
*/
204+
private void waitForError(TestClient client, String secret) throws Exception {
205+
boolean helloSent = false;
206+
int maxTries = 10;
207+
for (int i = 0; i < maxTries; i++) {
208+
try {
209+
if (!helloSent) {
210+
client.send(new Hello(secret, "1.4.0"));
211+
helloSent = true;
212+
} else {
213+
client.send(new SetAppId("appId"));
214+
}
215+
fail("Expected error but message went through.");
216+
} catch (IllegalStateException | IOException e) {
217+
// Expected.
218+
break;
219+
} catch (AssertionError e) {
220+
if (i < maxTries - 1) {
221+
Thread.sleep(100);
222+
} else {
223+
throw new AssertionError("Test failed after " + maxTries + " attempts.", e);
224+
}
225+
}
226+
}
227+
}
228+
202229
private static class TestClient extends LauncherConnection {
203230

204231
final BlockingQueue<Message> inbound;
@@ -220,4 +247,19 @@ protected void handle(Message msg) throws IOException {
220247

221248
}
222249

250+
private static class EvilPayload extends LauncherProtocol.Message {
251+
252+
static int EVIL_BIT = 0;
253+
254+
// This field should cause the launcher server to throw an error and not deserialize the
255+
// message.
256+
private List<String> notAllowedField = Arrays.asList("disallowed");
257+
258+
private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
259+
stream.defaultReadObject();
260+
EVIL_BIT = 1;
261+
}
262+
263+
}
264+
223265
}

0 commit comments

Comments
 (0)