Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust the handling of RSV1/RSV2/RSV3 in the translateSingleFrame #1232

Merged
merged 2 commits into from
Apr 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions src/main/java/org/java_websocket/drafts/Draft_6455.java
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,23 @@ public class Draft_6455 extends Draft {
/**
* Attribute for the used extension in this draft
*/
private IExtension extension = new DefaultExtension();
private IExtension negotiatedExtension = new DefaultExtension();

/**
* Attribute for the default extension
*/
private IExtension defaultExtension = new DefaultExtension();

/**
* Attribute for all available extension in this draft
*/
private List<IExtension> knownExtensions;

/**
* Current active extension used to decode messages
*/
private IExtension currentDecodingExtension;

/**
* Attribute for the used protocol in this draft
*/
Expand Down Expand Up @@ -241,10 +251,11 @@ public Draft_6455(List<IExtension> inputExtensions, List<IProtocol> inputProtoco
knownExtensions.addAll(inputExtensions);
//We always add the DefaultExtension to implement the normal RFC 6455 specification
if (!hasDefault) {
knownExtensions.add(this.knownExtensions.size(), extension);
knownExtensions.add(this.knownExtensions.size(), negotiatedExtension);
}
knownProtocols.addAll(inputProtocols);
maxFrameSize = inputMaxFrameSize;
currentDecodingExtension = null;
}

@Override
Expand All @@ -259,9 +270,9 @@ public HandshakeState acceptHandshakeAsServer(ClientHandshake handshakedata)
String requestedExtension = handshakedata.getFieldValue(SEC_WEB_SOCKET_EXTENSIONS);
for (IExtension knownExtension : knownExtensions) {
if (knownExtension.acceptProvidedExtensionAsServer(requestedExtension)) {
extension = knownExtension;
negotiatedExtension = knownExtension;
extensionState = HandshakeState.MATCHED;
log.trace("acceptHandshakeAsServer - Matching extension found: {}", extension);
log.trace("acceptHandshakeAsServer - Matching extension found: {}", negotiatedExtension);
break;
}
}
Expand Down Expand Up @@ -316,9 +327,9 @@ public HandshakeState acceptHandshakeAsClient(ClientHandshake request, ServerHan
String requestedExtension = response.getFieldValue(SEC_WEB_SOCKET_EXTENSIONS);
for (IExtension knownExtension : knownExtensions) {
if (knownExtension.acceptProvidedExtensionAsClient(requestedExtension)) {
extension = knownExtension;
negotiatedExtension = knownExtension;
extensionState = HandshakeState.MATCHED;
log.trace("acceptHandshakeAsClient - Matching extension found: {}", extension);
log.trace("acceptHandshakeAsClient - Matching extension found: {}", negotiatedExtension);
break;
}
}
Expand All @@ -337,7 +348,7 @@ public HandshakeState acceptHandshakeAsClient(ClientHandshake request, ServerHan
* @return the extension which is used or null, if handshake is not yet done
*/
public IExtension getExtension() {
return extension;
return negotiatedExtension;
}

/**
Expand Down Expand Up @@ -562,8 +573,20 @@ private Framedata translateSingleFrame(ByteBuffer buffer)
frame.setRSV3(rsv3);
payload.flip();
frame.setPayload(payload);
getExtension().isFrameValid(frame);
getExtension().decodeFrame(frame);
if (frame.getOpcode() != Opcode.CONTINUOUS) {
// Prioritize the negotiated extension
if (frame.isRSV1() || frame.isRSV2() || frame.isRSV3()) {
currentDecodingExtension = getExtension();
} else {
// No encoded message, so we can use the default one
currentDecodingExtension = defaultExtension;
}
}
if (currentDecodingExtension == null) {
currentDecodingExtension = defaultExtension;
}
currentDecodingExtension.isFrameValid(frame);
currentDecodingExtension.decodeFrame(frame);
if (log.isTraceEnabled()) {
log.trace("afterDecoding({}): {}", frame.getPayloadData().remaining(),
(frame.getPayloadData().remaining() > 1000 ? "too big to display"
Expand Down Expand Up @@ -780,10 +803,10 @@ public List<Framedata> createFrames(String text, boolean mask) {
@Override
public void reset() {
incompleteframe = null;
if (extension != null) {
extension.reset();
if (negotiatedExtension != null) {
negotiatedExtension.reset();
}
extension = new DefaultExtension();
negotiatedExtension = new DefaultExtension();
protocol = null;
}

Expand Down Expand Up @@ -1116,15 +1139,15 @@ public boolean equals(Object o) {
if (maxFrameSize != that.getMaxFrameSize()) {
return false;
}
if (extension != null ? !extension.equals(that.getExtension()) : that.getExtension() != null) {
if (negotiatedExtension != null ? !negotiatedExtension.equals(that.getExtension()) : that.getExtension() != null) {
return false;
}
return protocol != null ? protocol.equals(that.getProtocol()) : that.getProtocol() == null;
}

@Override
public int hashCode() {
int result = extension != null ? extension.hashCode() : 0;
int result = negotiatedExtension != null ? negotiatedExtension.hashCode() : 0;
result = 31 * result + (protocol != null ? protocol.hashCode() : 0);
result = 31 * result + (maxFrameSize ^ (maxFrameSize >>> 32));
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,15 @@ public void decodeFrame(Framedata inputFrame) throws InvalidDataException {
return;
}

if (!inputFrame.isRSV1() && inputFrame.getOpcode() != Opcode.CONTINUOUS) {
return;
}

// RSV1 bit must be set only for the first frame.
if (inputFrame.getOpcode() == Opcode.CONTINUOUS && inputFrame.isRSV1()) {
throw new InvalidDataException(CloseFrame.POLICY_VALIDATION,
"RSV1 bit can only be set for the first frame.");
}
// If rsv1 is not set, we dont have a compressed message
if (!inputFrame.isRSV1()) {
return;
}

// Decompressed output buffer.
ByteArrayOutputStream output = new ByteArrayOutputStream();
Expand Down Expand Up @@ -181,11 +181,6 @@ We can check the getRemaining() method to see whether the data we supplied has b
throw new InvalidDataException(CloseFrame.POLICY_VALIDATION, e.getMessage());
}

// RSV1 bit must be cleared after decoding, so that other extensions don't throw an exception.
if (inputFrame.isRSV1()) {
((DataFrame) inputFrame).setRSV1(false);
}

// Set frames payload to the new decompressed data.
((FramedataImpl1) inputFrame)
.setPayload(ByteBuffer.wrap(output.toByteArray(), 0, output.size()));
Expand Down