Skip to content

Commit

Permalink
feat: Add complete server/client TLS support (#158)
Browse files Browse the repository at this point in the history
BREAKING CHANGE: TLS client API now matches NodeJS official tls API.
  • Loading branch information
Rapsssito authored Aug 21, 2022
1 parent 755d7cb commit 3264f44
Show file tree
Hide file tree
Showing 36 changed files with 2,609 additions and 1,445 deletions.
252 changes: 193 additions & 59 deletions README.md

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import android.annotation.SuppressLint;
import android.content.Context;

import androidx.annotation.NonNull;
import androidx.annotation.RawRes;

import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
Expand All @@ -13,14 +16,16 @@
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509ExtendedKeyManager;
import javax.net.ssl.X509TrustManager;

import androidx.annotation.NonNull;
import androidx.annotation.RawRes;

final class SSLCertificateHelper {
/**
Expand All @@ -34,6 +39,23 @@ static SSLSocketFactory createBlindSocketFactory() throws GeneralSecurityExcepti
return ctx.getSocketFactory();
}

static SSLServerSocketFactory createServerSocketFactory(Context context, @NonNull final String keyStoreResourceUri) throws GeneralSecurityException, IOException {
char[] password = "".toCharArray();

InputStream keyStoreInput = getRawResourceStream(context, keyStoreResourceUri);
KeyStore keyStore = KeyStore.getInstance("PKCS12");
keyStore.load(keyStoreInput, password);
keyStoreInput.close();

KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance("X509");
keyManagerFactory.init(keyStore, password);

SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(keyManagerFactory.getKeyManagers(), new TrustManager[]{new BlindTrustManager()}, null);

return sslContext.getServerSocketFactory();
}

/**
* Creates an SSLSocketFactory instance for use with the CA provided in the resource file.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.asterinet.react.tcpsocket;

import android.util.Base64;
import android.util.Log;

import com.facebook.react.bridge.Arguments;
import com.facebook.react.bridge.ReactContext;
Expand All @@ -24,6 +25,14 @@ public TcpEventListener(final ReactContext reactContext) {
}

public void onConnection(int serverId, int clientId, Socket socket) {
onSocketConnection("connection", serverId, clientId, socket);
}

public void onSecureConnection(int serverId, int clientId, Socket socket) {
onSocketConnection("secureConnection", serverId, clientId, socket);
}

private void onSocketConnection(String connectionType, int serverId, int clientId, Socket socket) {
WritableMap eventParams = Arguments.createMap();
eventParams.putInt("id", serverId);

Expand All @@ -42,7 +51,7 @@ public void onConnection(int serverId, int clientId, Socket socket) {
infoParams.putMap("connection", connectionParams);
eventParams.putMap("info", infoParams);

sendEvent("connection", eventParams);
sendEvent(connectionType, eventParams);
}

public void onConnect(int id, TcpSocketClient client) {
Expand Down Expand Up @@ -83,7 +92,12 @@ public void onData(int id, byte[] data) {
sendEvent("data", eventParams);
}

public void onWritten(int id, int msgId, @Nullable String error) {
public void onWritten(int id, int msgId, @Nullable Exception e) {
String error = null;
if (e != null) {
Log.e(TcpSocketModule.TAG, "Exception on socket " + id, e);
error = e.getMessage();
}
WritableMap eventParams = Arguments.createMap();
eventParams.putInt("id", id);
eventParams.putInt("msgId", msgId);
Expand All @@ -92,18 +106,20 @@ public void onWritten(int id, int msgId, @Nullable String error) {
sendEvent("written", eventParams);
}

public void onClose(int id, String error) {
if (error != null) {
onError(id, error);
public void onClose(int id, Exception e) {
if (e != null) {
onError(id, e);
}
WritableMap eventParams = Arguments.createMap();
eventParams.putInt("id", id);
eventParams.putBoolean("hadError", error != null);
eventParams.putBoolean("hadError", e != null);

sendEvent("close", eventParams);
}

public void onError(int id, String error) {
public void onError(int id, Exception e) {
Log.e(TcpSocketModule.TAG, "Exception on socket " + id, e);
String error = e.getMessage();
WritableMap eventParams = Arguments.createMap();
eventParams.putInt("id", id);
eventParams.putString("error", error);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import javax.net.SocketFactory;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;

Expand All @@ -25,6 +24,7 @@ class TcpSocketClient extends TcpSocket {
private final TcpEventListener receiverListener;
private TcpReceiverTask receiverTask;
private Socket socket;
private boolean closed = true;

TcpSocketClient(TcpEventListener receiverListener, Integer id, Socket socket) {
super(id);
Expand All @@ -38,20 +38,12 @@ public Socket getSocket() {
return socket;
}

public void connect(Context context, String address, final Integer port, ReadableMap options, Network network) throws IOException, GeneralSecurityException {
public void connect(Context context, String address, final Integer port, ReadableMap options, Network network, ReadableMap tlsOptions) throws IOException, GeneralSecurityException {
if (socket != null) throw new IOException("Already connected");
final boolean isTls = options.hasKey("tls") && options.getBoolean("tls");
if (isTls) {
SocketFactory sf;
if (options.hasKey("tlsCheckValidity") && !options.getBoolean("tlsCheckValidity")) {
sf = SSLCertificateHelper.createBlindSocketFactory();
} else {
final String customTlsCert = options.hasKey("tlsCert") ? options.getString("tlsCert") : null;
sf = customTlsCert != null ? SSLCertificateHelper.createCustomTrustedSocketFactory(context, customTlsCert) : SSLSocketFactory.getDefault();
}
final SSLSocket sslSocket = (SSLSocket) sf.createSocket();
sslSocket.setUseClientMode(true);
socket = sslSocket;
if (tlsOptions != null) {
SSLSocketFactory ssf = getSSLSocketFactory(context, tlsOptions);
socket = ssf.createSocket();
((SSLSocket) socket).setUseClientMode(true);
} else {
socket = new Socket();
}
Expand All @@ -73,10 +65,30 @@ public void connect(Context context, String address, final Integer port, Readabl
// bind
socket.bind(new InetSocketAddress(localInetAddress, localPort));
socket.connect(new InetSocketAddress(remoteInetAddress, port));
if (isTls) ((SSLSocket) socket).startHandshake();
if (socket instanceof SSLSocket) ((SSLSocket) socket).startHandshake();
startListening();
}

public void startTLS(Context context, ReadableMap tlsOptions) throws IOException, GeneralSecurityException {
if (socket instanceof SSLSocket) return;
SSLSocketFactory ssf = getSSLSocketFactory(context, tlsOptions);
SSLSocket sslSocket = (SSLSocket) ssf.createSocket(socket, socket.getInetAddress().getHostAddress(), socket.getPort(), true);
sslSocket.setUseClientMode(true);
sslSocket.startHandshake();
socket = sslSocket;
}

private SSLSocketFactory getSSLSocketFactory(Context context, ReadableMap tlsOptions) throws GeneralSecurityException, IOException {
SSLSocketFactory ssf;
if (tlsOptions.hasKey("rejectUnauthorized") && !tlsOptions.getBoolean("rejectUnauthorized")) {
ssf = SSLCertificateHelper.createBlindSocketFactory();
} else {
final String customTlsCert = tlsOptions.hasKey("ca") ? tlsOptions.getString("ca") : null;
ssf = customTlsCert != null ? SSLCertificateHelper.createCustomTrustedSocketFactory(context, customTlsCert) : (SSLSocketFactory) SSLSocketFactory.getDefault();
}
return ssf;
}

public void startListening() {
receiverTask = new TcpReceiverTask(this, receiverListener);
listenExecutor.execute(receiverTask);
Expand All @@ -95,8 +107,8 @@ public void run() {
socket.getOutputStream().write(data);
receiverListener.onWritten(getId(), msgId, null);
} catch (IOException e) {
receiverListener.onWritten(getId(), msgId, e.toString());
receiverListener.onError(getId(), e.toString());
receiverListener.onWritten(getId(), msgId, e);
receiverListener.onError(getId(), e);
}
}
});
Expand All @@ -109,12 +121,13 @@ public void destroy() {
try {
// close the socket
if (socket != null && !socket.isClosed()) {
closed = true;
socket.close();
receiverListener.onClose(getId(), null);
socket = null;
}
} catch (IOException e) {
receiverListener.onClose(getId(), e.getMessage());
receiverListener.onClose(getId(), e);
}
}

Expand Down Expand Up @@ -183,8 +196,8 @@ public void run() {
}
}
} catch (IOException | InterruptedException ioe) {
if (receiverListener != null && !socket.isClosed()) {
receiverListener.onError(socketId, ioe.getMessage());
if (receiverListener != null && !socket.isClosed() && !clientSocket.closed) {
receiverListener.onError(socketId, ioe);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
import android.annotation.SuppressLint;
import android.content.Context;
import android.net.ConnectivityManager;
import android.net.Network;
import android.net.NetworkCapabilities;
import android.net.NetworkRequest;
import android.util.Base64;
import android.net.Network;

import androidx.annotation.NonNull;
import androidx.annotation.Nullable;

import com.facebook.react.bridge.ReactApplicationContext;
import com.facebook.react.bridge.ReactContextBaseJavaModule;
Expand All @@ -22,14 +25,12 @@
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

import androidx.annotation.NonNull;
import androidx.annotation.Nullable;

public class TcpSocketModule extends ReactContextBaseJavaModule {
private static final String TAG = "TcpSockets";
public static final String TAG = "TcpSockets";
private static final int N_THREADS = 2;
private final ReactApplicationContext mReactContext;
private final ConcurrentHashMap<Integer, TcpSocket> socketMap = new ConcurrentHashMap<>();
private final ConcurrentHashMap<Integer, ReadableMap> pendingTLS = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String, Network> mNetworkMap = new ConcurrentHashMap<>();
private final CurrentNetwork currentNetwork = new CurrentNetwork();
private final ExecutorService executorService = Executors.newFixedThreadPool(N_THREADS);
Expand Down Expand Up @@ -68,7 +69,7 @@ public void connect(@NonNull final Integer cId, @NonNull final String host, @Non
@Override
public void run() {
if (socketMap.get(cId) != null) {
tcpEvtListener.onError(cId, TAG + "createSocket called twice with the same id.");
tcpEvtListener.onError(cId, new Exception("connect() called twice with the same id."));
return;
}
try {
Expand All @@ -78,15 +79,33 @@ public void run() {
selectNetwork(iface, localAddress);
TcpSocketClient client = new TcpSocketClient(tcpEvtListener, cId, null);
socketMap.put(cId, client);
client.connect(mReactContext, host, port, options, currentNetwork.getNetwork());
ReadableMap tlsOptions = pendingTLS.get(cId);
client.connect(mReactContext, host, port, options, currentNetwork.getNetwork(), tlsOptions);
tcpEvtListener.onConnect(cId, client);
} catch (Exception e) {
tcpEvtListener.onError(cId, e.getMessage());
tcpEvtListener.onError(cId, e);
}
}
});
}

@SuppressLint("StaticFieldLeak")
@SuppressWarnings("unused")
@ReactMethod
public void startTLS(final int cId, @NonNull final ReadableMap tlsOptions) {
TcpSocketClient socketClient = (TcpSocketClient) socketMap.get(cId);
// Not yet connected
if (socketClient == null) {
pendingTLS.put(cId, tlsOptions);
} else {
try {
socketClient.startTLS(mReactContext, tlsOptions);
} catch (Exception e) {
tcpEvtListener.onError(cId, e);
}
}
}

@SuppressLint("StaticFieldLeak")
@SuppressWarnings("unused")
@ReactMethod
Expand Down Expand Up @@ -137,11 +156,11 @@ public void listen(final Integer cId, final ReadableMap options) {
@Override
public void run() {
try {
TcpSocketServer server = new TcpSocketServer(socketMap, tcpEvtListener, cId, options);
TcpSocketServer server = new TcpSocketServer(mReactContext, socketMap, tcpEvtListener, cId, options);
socketMap.put(cId, server);
tcpEvtListener.onListen(cId, server);
} catch (Exception uhe) {
tcpEvtListener.onError(cId, uhe.getMessage());
tcpEvtListener.onError(cId, uhe);
}
}
});
Expand All @@ -154,7 +173,7 @@ public void setNoDelay(@NonNull final Integer cId, final boolean noDelay) {
try {
client.setNoDelay(noDelay);
} catch (IOException e) {
tcpEvtListener.onError(cId, e.getMessage());
tcpEvtListener.onError(cId, e);
}
}

Expand All @@ -165,7 +184,7 @@ public void setKeepAlive(@NonNull final Integer cId, final boolean enable, final
try {
client.setKeepAlive(enable, initialDelay);
} catch (IOException e) {
tcpEvtListener.onError(cId, e.getMessage());
tcpEvtListener.onError(cId, e);
}
}

Expand All @@ -182,7 +201,7 @@ public void resume(final int cId) {
TcpSocketClient client = getTcpClient(cId);
client.resume();
}

@SuppressWarnings("unused")
@ReactMethod
public void addListener(String eventName) {
Expand Down Expand Up @@ -260,21 +279,21 @@ private void selectNetwork(@Nullable final String iface, @Nullable final String
private TcpSocketClient getTcpClient(final int id) {
TcpSocket socket = socketMap.get(id);
if (socket == null) {
throw new IllegalArgumentException(TAG + "No socket with id " + id);
throw new IllegalArgumentException("No socket with id " + id);
}
if (!(socket instanceof TcpSocketClient)) {
throw new IllegalArgumentException(TAG + "Socket with id " + id + " is not a client");
throw new IllegalArgumentException("Socket with id " + id + " is not a client");
}
return (TcpSocketClient) socket;
}

private TcpSocketServer getTcpServer(final int id) {
TcpSocket socket = socketMap.get(id);
if (socket == null) {
throw new IllegalArgumentException(TAG + "No socket with id " + id);
throw new IllegalArgumentException("No server socket with id " + id);
}
if (!(socket instanceof TcpSocketServer)) {
throw new IllegalArgumentException(TAG + "Socket with id " + id + " is not a server");
throw new IllegalArgumentException("Server socket with id " + id + " is not a server");
}
return (TcpSocketServer) socket;
}
Expand Down
Loading

0 comments on commit 3264f44

Please sign in to comment.