Skip to content

Integrated an SSLSocketChannel class that allows wss support #101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 12, 2012
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions build.xml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
<target name="compile">
<mkdir dir="build/classes" />
<mkdir dir="build/examples" />
<javac includeantruntime="false" debug="on" srcdir="src" destdir="build/classes" />
<javac includeantruntime="false" srcdir="example" classpath="build/classes" destdir="build/examples" />
<javac includeantruntime="false" debug="on" srcdir="src" destdir="build/classes" target="1.5" source="1.5" />
<javac includeantruntime="false" srcdir="example" classpath="build/classes" destdir="build/examples" target="1.5" source="1.5" />
</target>

<target name="jar" depends="compile">
Expand Down
Binary file modified dist/WebSocket.jar
Binary file not shown.
156 changes: 156 additions & 0 deletions example/SSLServer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package test;

import java.util.List;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;

import java.security.KeyStore;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.GeneralSecurityException;
import java.security.NoSuchProviderException;
import java.security.InvalidKeyException;
import java.security.Security;
import java.security.SignatureException;

import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.KeyManagerFactory;

import java.nio.channels.SocketChannel;

import java.net.InetAddress;
import java.net.InetSocketAddress;

import org.java_websocket.WebSocket;
import org.java_websocket.WebSocketServer;
import org.java_websocket.handshake.ClientHandshake;
import org.java_websocket.WebSocketAdapter;
import org.java_websocket.WebSocketImpl;
import org.java_websocket.SSLSocketChannel;
import org.java_websocket.drafts.Draft;

/*
* Create the appropriate websocket server.
*/
public class SSLServer implements WebSocketServer.WebSocketServerFactory
{
private static final String STORETYPE = "JKS";
private static final String KEYSTORE = "keystore.jks";
private static final String STOREPASSWORD = "storepassword";
private static final String KEYPASSWORD = "keypassword";

public static void main(String[] args) throws Exception
{
new SSLServer();
}

private SSLContext sslContext;

void loadFromFile() throws Exception
{
// load up the key store
KeyStore ks = KeyStore.getInstance(STORETYPE);
File kf = new File(KEYSTORE);
ks.load(new FileInputStream(kf), STOREPASSWORD.toCharArray());

KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
kmf.init(ks, KEYPASSWORD.toCharArray());
TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
tmf.init(ks);

sslContext = SSLContext.getInstance("TLS");
sslContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
}

/*
* Keystore with certificate created like so (in JKS format):
*
keytool -genkey -validity 3650 -keystore "keystore.jks" -storepass "storepassword" -keypass "keypassword" -alias "default" -dname "CN=127.0.0.1, OU=MyOrgUnit, O=MyOrg, L=MyCity, S=MyRegion, C=MyCountry"
*/
SSLServer() throws Exception
{
sslContext = null;
loadFromFile();

// create the web socket server
WebSocketSource wsgateway = new WebSocketSource(8001, InetAddress.getByName("127.0.0.1"));
wsgateway.setWebSocketFactory(this);
wsgateway.start();
}

@Override
public WebSocketImpl createWebSocket( WebSocketAdapter a, Draft d, SocketChannel c ) {
if(sslContext != null) try{
SSLEngine e = sslContext.createSSLEngine();
e.setUseClientMode(false);
return new WebSocketImpl( a, d, new SSLSocketChannel(c, e));
}catch(Exception e1){}
return new WebSocketImpl( a, d, c );
}

@Override
public WebSocketImpl createWebSocket( WebSocketAdapter a, List<Draft> d, SocketChannel c ) {
if(sslContext != null) try{
SSLEngine e = sslContext.createSSLEngine();
e.setUseClientMode(false);
return new WebSocketImpl( a, d, new SSLSocketChannel(c, e)); }catch(Exception e1){}
return new WebSocketImpl( a, d, c );
}

class WebSocketSource extends WebSocketServer
{
private WebSocket handle;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fomojola I am just testing around with ssl trying to simplify things if possible: 98f4a3c

And i am not sure what the 'handle' is for...? It seems to limit the number of simultaneous connections...

But even when i remove the handle as you see in my code my firefox somehow fails to create more that one ssl connection to the example server at a time.

I am probably missing some basic mechanic...

Do you have an explanation for that behaviour?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right: handle was basically a remnant of my application that made it into the test file, since in my app I only allowed 1 connection at a time.
Creating multiple connections should be just fine: just strip out all references to handle and use arg0. I'm not sure what the default backlog is: I think it is 50 or something, but the base WebSocketServer doesn't set it when calling bind, so there may be something there that prevents multiple connections.

WebSocketSource(int port, InetAddress addr)
{
super(new InetSocketAddress(addr, port));
handle = null;
}

@Override
public void onClose(WebSocket arg0, int arg1, String arg2, boolean arg3)
{
System.err.println("---------------------------->Closed");
if(arg0 == handle) handle = null;
}

@Override
public void onError(WebSocket arg0, Exception arg1) {
// TODO Auto-generated method stub
}

@Override
public void onMessage(WebSocket arg0, String arg1)
{
if(arg0 != handle){
arg0.close(org.java_websocket.framing.CloseFrame.NORMAL);
return;
}

System.out.println("--------->["+arg1+"]");
}

@Override
public void onOpen(WebSocket arg0, ClientHandshake arg1)
{
// nothing to see just yet
if(handle == null){
handle = arg0;
}else if(handle != arg0){
arg0.close(org.java_websocket.framing.CloseFrame.NORMAL);
}
}

void done()
{
if(handle != null) handle.close(org.java_websocket.framing.CloseFrame.NORMAL);
}
}
}
204 changes: 204 additions & 0 deletions src/org/java_websocket/SSLSocketChannel.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
/**
* Copyright (C) 2003 Alexander Kout
* Originally from the jFxp project (http://jfxp.sourceforge.net/).
* Copied with permission June 11, 2012 by Femi Omojola (fomojola@ideasynthesis.com).
*/
package org.java_websocket;

import java.net.*;
import java.nio.*;
import java.nio.channels.*;
import javax.net.ssl.*;
import java.io.*;

/**
* Implements the relevant portions of the SocketChannel interface with the SSLEngine wrapper.
*/
public class SSLSocketChannel
{
private ByteBuffer clientIn, clientOut, cTOs, sTOc, wbuf;
private SocketChannel sc;
private SSLEngineResult res;
private SSLEngine sslEngine;
private int SSL;

public SSLSocketChannel(SocketChannel sc, SSLEngine sslEngine) throws IOException
{
this.sc = sc;
this.sslEngine = sslEngine;
SSL = 1;
try {
sslEngine.setEnableSessionCreation(true);
SSLSession session = sslEngine.getSession();
createBuffers(session);
// wrap
clientOut.clear();
sc.write(wrap(clientOut));
while (res.getHandshakeStatus() !=
SSLEngineResult.HandshakeStatus.FINISHED) {
if (res.getHandshakeStatus() ==
SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
// unwrap
sTOc.clear();
while (sc.read(sTOc) < 1)
Thread.sleep(20);
sTOc.flip();
unwrap(sTOc);
if (res.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.FINISHED) {
clientOut.clear();
sc.write(wrap(clientOut));
}
} else if (res.getHandshakeStatus() ==
SSLEngineResult.HandshakeStatus.NEED_WRAP) {
// wrap
clientOut.clear();
sc.write(wrap(clientOut));
} else {Thread.sleep(1000);}
}
clientIn.clear();
clientIn.flip();
SSL = 4;
} catch (Exception e) {
e.printStackTrace(System.out);
SSL = 0;
}
}

private synchronized ByteBuffer wrap(ByteBuffer b) throws SSLException {
cTOs.clear();
res = sslEngine.wrap(b, cTOs);
cTOs.flip();
return cTOs;
}

private synchronized ByteBuffer unwrap(ByteBuffer b) throws SSLException {
clientIn.clear();
int pos;
while (b.hasRemaining()) {
res = sslEngine.unwrap(b, clientIn);
if (res.getHandshakeStatus() ==
SSLEngineResult.HandshakeStatus.NEED_TASK) {
// Task
Runnable task;
while ((task=sslEngine.getDelegatedTask()) != null)
{
task.run();
}
} else if (res.getHandshakeStatus() ==
SSLEngineResult.HandshakeStatus.FINISHED) {
return clientIn;
} else if (res.getStatus() ==
SSLEngineResult.Status.BUFFER_UNDERFLOW) {
return clientIn;
}
}
return clientIn;
}

private void createBuffers(SSLSession session) {

int appBufferMax = session.getApplicationBufferSize();
int netBufferMax = session.getPacketBufferSize();

clientIn = ByteBuffer.allocate(65536);
clientOut = ByteBuffer.allocate(appBufferMax);
wbuf = ByteBuffer.allocate(65536);

cTOs = ByteBuffer.allocate(netBufferMax);
sTOc = ByteBuffer.allocate(netBufferMax);

}

public int write(ByteBuffer src) throws IOException {
if (SSL == 4) {
return sc.write(wrap(src));
}
return sc.write(src);
}

public int read(ByteBuffer dst) throws IOException {
int amount = 0, limit;
if (SSL == 4) {
// test if there was a buffer overflow in dst
if (clientIn.hasRemaining()) {
limit = Math.min(clientIn.remaining(), dst.remaining());
for (int i = 0; i < limit; i++) {
dst.put(clientIn.get());
amount++;
}
return amount;
}
// test if some bytes left from last read (e.g. BUFFER_UNDERFLOW)
if (sTOc.hasRemaining()) {
unwrap(sTOc);
clientIn.flip();
limit = Math.min(clientIn.limit(), dst.remaining());
for (int i = 0; i < limit; i++) {
dst.put(clientIn.get());
amount++;
}
if (res.getStatus() != SSLEngineResult.Status.BUFFER_UNDERFLOW) {
sTOc.clear();
sTOc.flip();
return amount;
}
}
if (!sTOc.hasRemaining())
sTOc.clear();
else
sTOc.compact();

if (sc.read(sTOc) == -1) {
sTOc.clear();
sTOc.flip();
return -1;
}
sTOc.flip();
unwrap(sTOc);
// write in dst
clientIn.flip();
limit = Math.min(clientIn.limit(), dst.remaining());
for (int i = 0; i < limit; i++) {
dst.put(clientIn.get());
amount++;
}
return amount;
}
return sc.read(dst);
}

public boolean isConnected() {
return sc.isConnected();
}

public void close() throws IOException {
if (SSL == 4) {
sslEngine.closeOutbound();
sslEngine.getSession().invalidate();
clientOut.clear();
sc.write(wrap(clientOut));
sc.close();
} else
sc.close();
}

public SelectableChannel configureBlocking(boolean b) throws IOException {
return sc.configureBlocking(b);
}

public boolean connect(SocketAddress remote) throws IOException {
return sc.connect(remote);
}

public boolean finishConnect() throws IOException {
return sc.finishConnect();
}

public Socket socket() {
return sc.socket();
}

public boolean isInboundDone() {
return sslEngine.isInboundDone();
}
}
2 changes: 1 addition & 1 deletion src/org/java_websocket/WebSocket.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public enum Role {
CLIENT, SERVER
}

public static int RCVBUF = 256;
public static int RCVBUF = 16384;

public static/*final*/boolean DEBUG = false; // must be final in the future in order to take advantage of VM optimization

Expand Down
Loading