Skip to content

Commit 811a66d

Browse files
refactor: Make sure we only write W3C payload into create session command (#1537)
1 parent 92396dc commit 811a66d

File tree

4 files changed

+232
-9
lines changed

4 files changed

+232
-9
lines changed

src/main/java/io/appium/java_client/AppiumDriver.java

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import io.appium.java_client.internal.CapabilityHelpers;
2727
import io.appium.java_client.internal.JsonToMobileElementConverter;
2828
import io.appium.java_client.remote.AppiumCommandExecutor;
29+
import io.appium.java_client.remote.AppiumNewSessionCommandPayload;
2930
import io.appium.java_client.remote.MobileCapabilityType;
3031
import io.appium.java_client.service.local.AppiumDriverLocalService;
3132
import io.appium.java_client.service.local.AppiumServiceBuilder;
@@ -34,6 +35,7 @@
3435
import org.openqa.selenium.DeviceRotation;
3536
import org.openqa.selenium.MutableCapabilities;
3637
import org.openqa.selenium.ScreenOrientation;
38+
import org.openqa.selenium.SessionNotCreatedException;
3739
import org.openqa.selenium.WebDriver;
3840
import org.openqa.selenium.WebDriverException;
3941
import org.openqa.selenium.WebElement;
@@ -44,11 +46,13 @@
4446
import org.openqa.selenium.remote.ErrorHandler;
4547
import org.openqa.selenium.remote.ExecuteMethod;
4648
import org.openqa.selenium.remote.HttpCommandExecutor;
49+
import org.openqa.selenium.remote.RemoteWebDriver;
4750
import org.openqa.selenium.remote.Response;
4851
import org.openqa.selenium.remote.html5.RemoteLocationContext;
4952
import org.openqa.selenium.remote.http.HttpClient;
5053
import org.openqa.selenium.remote.http.HttpMethod;
5154

55+
import java.lang.reflect.Field;
5256
import java.net.URL;
5357
import java.util.Arrays;
5458
import java.util.LinkedHashSet;
@@ -299,14 +303,33 @@ public boolean isBrowser() {
299303

300304
@Override
301305
protected void startSession(Capabilities capabilities) {
302-
super.startSession(capabilities);
303-
// The RemoteWebDriver implementation overrides platformName
304-
// so we need to restore it back to the original value
305-
Object originalPlatformName = capabilities.getCapability(PLATFORM_NAME);
306-
Capabilities originalCaps = super.getCapabilities();
307-
if (originalPlatformName != null && originalCaps instanceof MutableCapabilities) {
308-
((MutableCapabilities) super.getCapabilities()).setCapability(PLATFORM_NAME,
309-
originalPlatformName);
306+
Response response = execute(new AppiumNewSessionCommandPayload(capabilities));
307+
if (response == null) {
308+
throw new SessionNotCreatedException(
309+
"The underlying command executor returned a null response.");
310310
}
311+
312+
Object responseValue = response.getValue();
313+
if (responseValue == null) {
314+
throw new SessionNotCreatedException(
315+
"The underlying command executor returned a response without payload: "
316+
+ response);
317+
}
318+
if (!(responseValue instanceof Map)) {
319+
throw new SessionNotCreatedException(
320+
"The underlying command executor returned a response with a non well formed payload: "
321+
+ response);
322+
}
323+
324+
@SuppressWarnings("unchecked") Map<String, Object> rawCapabilities = (Map<String, Object>) responseValue;
325+
MutableCapabilities returnedCapabilities = new MutableCapabilities(rawCapabilities);
326+
try {
327+
Field capsField = RemoteWebDriver.class.getDeclaredField("capabilities");
328+
capsField.setAccessible(true);
329+
capsField.set(this, returnedCapabilities);
330+
} catch (NoSuchFieldException | IllegalAccessException e) {
331+
throw new WebDriverException(e);
332+
}
333+
setSessionId(response.getSessionId());
311334
}
312335
}

src/main/java/io/appium/java_client/remote/AppiumCommandExecutor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ private Response createSession(Command command) throws IOException {
145145
throw new SessionNotCreatedException("Session already exists");
146146
}
147147

148-
ProtocolHandshake.Result result = new ProtocolHandshake().createSession(
148+
ProtocolHandshake.Result result = new AppiumProtocolHandshake().createSession(
149149
getClient().with((httpHandler) -> (req) -> {
150150
req.setHeader(IDEMPOTENCY_KEY_HEADER, UUID.randomUUID().toString().toLowerCase());
151151
return httpHandler.execute(req);
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* See the NOTICE file distributed with this work for additional
5+
* information regarding copyright ownership.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.appium.java_client.remote;
18+
19+
import com.google.common.collect.ImmutableMap;
20+
import com.google.common.collect.ImmutableSet;
21+
import org.openqa.selenium.Capabilities;
22+
import org.openqa.selenium.internal.Require;
23+
import org.openqa.selenium.remote.AcceptedW3CCapabilityKeys;
24+
import org.openqa.selenium.remote.CommandPayload;
25+
26+
import java.util.AbstractMap;
27+
import java.util.Map;
28+
29+
import static io.appium.java_client.internal.CapabilityHelpers.APPIUM_PREFIX;
30+
import static org.openqa.selenium.remote.DriverCommand.NEW_SESSION;
31+
32+
public class AppiumNewSessionCommandPayload extends CommandPayload {
33+
private static final AcceptedW3CCapabilityKeys ACCEPTED_W3C_PATTERNS = new AcceptedW3CCapabilityKeys();
34+
35+
/**
36+
* Appends "appium:" prefix to all non-prefixed non-standard capabilities.
37+
*
38+
* @param possiblyInvalidCapabilities user-provided capabilities mapping.
39+
* @return Fixed capabilities mapping.
40+
*/
41+
private static Map<String, Object> makeW3CSafe(Capabilities possiblyInvalidCapabilities) {
42+
Require.nonNull("Capabilities", possiblyInvalidCapabilities);
43+
44+
return possiblyInvalidCapabilities.asMap().entrySet().stream()
45+
.map((entry) -> ACCEPTED_W3C_PATTERNS.test(entry.getKey())
46+
? entry
47+
: new AbstractMap.SimpleEntry<>(
48+
String.format("%s%s", APPIUM_PREFIX, entry.getKey()), entry.getValue()))
49+
.collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
50+
}
51+
52+
/**
53+
* Overrides the default new session behavior to
54+
* only handle W3C capabilities.
55+
*
56+
* @param capabilities User-provided capabilities.
57+
*/
58+
public AppiumNewSessionCommandPayload(Capabilities capabilities) {
59+
super(NEW_SESSION, ImmutableMap.of(
60+
"capabilities", ImmutableSet.of(makeW3CSafe(capabilities)),
61+
"desiredCapabilities", capabilities
62+
));
63+
}
64+
}
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* See the NOTICE file distributed with this work for additional
5+
* information regarding copyright ownership.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.appium.java_client.remote;
18+
19+
import com.google.common.io.CountingOutputStream;
20+
import com.google.common.io.FileBackedOutputStream;
21+
import org.openqa.selenium.Capabilities;
22+
import org.openqa.selenium.ImmutableCapabilities;
23+
import org.openqa.selenium.SessionNotCreatedException;
24+
import org.openqa.selenium.WebDriverException;
25+
import org.openqa.selenium.internal.Either;
26+
import org.openqa.selenium.json.Json;
27+
import org.openqa.selenium.json.JsonOutput;
28+
import org.openqa.selenium.remote.Command;
29+
import org.openqa.selenium.remote.NewSessionPayload;
30+
import org.openqa.selenium.remote.ProtocolHandshake;
31+
import org.openqa.selenium.remote.http.HttpHandler;
32+
33+
import java.io.BufferedInputStream;
34+
import java.io.IOException;
35+
import java.io.InputStream;
36+
import java.io.OutputStreamWriter;
37+
import java.io.Writer;
38+
import java.lang.reflect.InvocationTargetException;
39+
import java.lang.reflect.Method;
40+
import java.util.Map;
41+
import java.util.Set;
42+
import java.util.stream.Stream;
43+
44+
import static java.nio.charset.StandardCharsets.UTF_8;
45+
46+
@SuppressWarnings("UnstableApiUsage")
47+
public class AppiumProtocolHandshake extends ProtocolHandshake {
48+
private static void writeJsonPayload(NewSessionPayload srcPayload, Appendable destination) {
49+
try (JsonOutput json = new Json().newOutput(destination)) {
50+
json.beginObject();
51+
52+
json.name("capabilities");
53+
json.beginObject();
54+
55+
json.name("firstMatch");
56+
json.beginArray();
57+
json.beginObject();
58+
json.endObject();
59+
json.endArray();
60+
61+
json.name("alwaysMatch");
62+
try {
63+
Method getW3CMethod = NewSessionPayload.class.getDeclaredMethod("getW3C");
64+
getW3CMethod.setAccessible(true);
65+
//noinspection unchecked
66+
((Stream<Map<String, Object>>) getW3CMethod.invoke(srcPayload))
67+
.findFirst()
68+
.map(json::write)
69+
.orElseGet(() -> {
70+
json.beginObject();
71+
json.endObject();
72+
return null;
73+
});
74+
} catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) {
75+
throw new WebDriverException(e);
76+
}
77+
78+
json.endObject(); // Close "capabilities" object
79+
80+
try {
81+
Method writeMetaDataMethod = NewSessionPayload.class.getDeclaredMethod(
82+
"writeMetaData", JsonOutput.class);
83+
writeMetaDataMethod.setAccessible(true);
84+
writeMetaDataMethod.invoke(srcPayload, json);
85+
} catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) {
86+
throw new WebDriverException(e);
87+
}
88+
89+
json.endObject();
90+
}
91+
}
92+
93+
@Override
94+
public Result createSession(HttpHandler client, Command command) throws IOException {
95+
//noinspection unchecked
96+
Capabilities desired = ((Set<Map<String, Object>>) command.getParameters().get("capabilities"))
97+
.stream()
98+
.findAny()
99+
.map(ImmutableCapabilities::new)
100+
.orElseGet(ImmutableCapabilities::new);
101+
try (NewSessionPayload payload = NewSessionPayload.create(desired)) {
102+
Either<SessionNotCreatedException, Result> result = createSession(client, payload);
103+
if (result.isRight()) {
104+
return result.right();
105+
}
106+
throw result.left();
107+
}
108+
}
109+
110+
@Override
111+
public Either<SessionNotCreatedException, Result> createSession(
112+
HttpHandler client, NewSessionPayload payload) throws IOException {
113+
int threshold = (int) Math.min(Runtime.getRuntime().freeMemory() / 10, Integer.MAX_VALUE);
114+
FileBackedOutputStream os = new FileBackedOutputStream(threshold);
115+
116+
try (CountingOutputStream counter = new CountingOutputStream(os);
117+
Writer writer = new OutputStreamWriter(counter, UTF_8)) {
118+
writeJsonPayload(payload, writer);
119+
120+
try (InputStream rawIn = os.asByteSource().openBufferedStream();
121+
BufferedInputStream contentStream = new BufferedInputStream(rawIn)) {
122+
Method createSessionMethod = ProtocolHandshake.class.getDeclaredMethod("createSession",
123+
HttpHandler.class, InputStream.class, long.class);
124+
createSessionMethod.setAccessible(true);
125+
//noinspection unchecked
126+
return (Either<SessionNotCreatedException, Result>) createSessionMethod.invoke(
127+
this, client, contentStream, counter.getCount()
128+
);
129+
} catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) {
130+
throw new WebDriverException(e);
131+
}
132+
} finally {
133+
os.reset();
134+
}
135+
}
136+
}

0 commit comments

Comments
 (0)