Skip to content

Commit b855d9c

Browse files
committed
Implemented as ServiceLoader, with helpers to tweak resolution.
1 parent 7a9c5a1 commit b855d9c

File tree

3 files changed

+100
-41
lines changed

3 files changed

+100
-41
lines changed

src/main/java/zmq/io/net/NetProtocol.java

Lines changed: 80 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
package zmq.io.net;
22

33
import java.util.Arrays;
4-
import java.util.HashMap;
54
import java.util.Locale;
65
import java.util.Map;
6+
import java.util.Optional;
7+
import java.util.ServiceLoader;
78
import java.util.Set;
9+
import java.util.concurrent.ConcurrentHashMap;
810
import java.util.function.BiConsumer;
911
import java.util.function.Consumer;
1012
import java.util.stream.Collectors;
@@ -16,13 +18,6 @@
1618
import zmq.io.IOThread;
1719
import zmq.io.SessionBase;
1820
import zmq.io.net.Address.IZAddress;
19-
import zmq.io.net.inproc.InprocNetworkProtocolProvider;
20-
import zmq.io.net.ipc.IpcNetworkProtocolProvider;
21-
import zmq.io.net.norm.NormNetworkProtocolProvider;
22-
import zmq.io.net.pgm.EpgmNetworkProtocolProvider;
23-
import zmq.io.net.pgm.PgmNetworkProtocolProvider;
24-
import zmq.io.net.tcp.TcpNetworkProtocolProvider;
25-
import zmq.io.net.tipc.TipcNetworkProtocolProvider;
2621
import zmq.socket.Sockets;
2722

2823
public enum NetProtocol
@@ -69,16 +64,68 @@ public void resolve(Address paddr, boolean ipv6)
6964
},
7065
vmci(true, true);
7166

72-
private static final Map<NetProtocol, NetworkProtocolProvider> providers;
73-
static {
74-
providers = new HashMap<>(NetProtocol.values().length);
75-
providers.put(NetProtocol.tcp, new TcpNetworkProtocolProvider());
76-
providers.put(NetProtocol.ipc, new IpcNetworkProtocolProvider());
77-
providers.put(NetProtocol.tipc, new TipcNetworkProtocolProvider());
78-
providers.put(NetProtocol.norm, new NormNetworkProtocolProvider());
79-
providers.put(NetProtocol.inproc, new InprocNetworkProtocolProvider());
80-
providers.put(NetProtocol.pgm, new PgmNetworkProtocolProvider());
81-
providers.put(NetProtocol.epgm, new EpgmNetworkProtocolProvider());
67+
private static final Map<NetProtocol, NetworkProtocolProvider> providers = new ConcurrentHashMap<>();
68+
69+
/**
70+
* @param protocol name
71+
* @throws IllegalArgumentException if the protocol name can be matched to an actual supported protocol
72+
* @return {@link NetProtocol} resolved by name
73+
*/
74+
public static NetProtocol getProtocol(String protocol)
75+
{
76+
try {
77+
return valueOf(protocol.toLowerCase(Locale.ENGLISH));
78+
}
79+
catch (NullPointerException | IllegalArgumentException e) {
80+
throw new IllegalArgumentException("Unknown protocol: \"" + protocol + "\"");
81+
}
82+
}
83+
84+
/**
85+
* <p>Load a {@link NetworkProtocolProvider} and ensure that if a class loader is given, it will be used. If multiple
86+
* instance are returned, the first one will be used.</p>
87+
*
88+
* <p>The purpose of this method is to be able to handle how a {@link NetworkProtocolProvider} service is resolver, by
89+
* tweaking CLASSPATH and class loader.</p>
90+
* @param proto The protocol to search
91+
* @param cl The class loader used to resolve the {@link NetworkProtocolProvider} or null if not required.
92+
*/
93+
public static void loadWithClassLoader(NetProtocol proto, ClassLoader cl)
94+
{
95+
Optional.ofNullable(resolveProtocol(proto, cl))
96+
.ifPresent(npp -> providers.put(proto, npp));
97+
}
98+
99+
private static NetworkProtocolProvider resolveProtocol(NetProtocol proto, ClassLoader cl)
100+
{
101+
ServiceLoader<NetworkProtocolProvider> serviceLoader = ServiceLoader.load(NetworkProtocolProvider.class, cl);
102+
return serviceLoader.stream()
103+
.map(ServiceLoader.Provider::get)
104+
.filter(npp -> npp.isValid() && npp.handleProtocol(proto) && (cl == null || npp.getClass().getClassLoader() == cl))
105+
.findFirst()
106+
.orElse(null);
107+
}
108+
109+
private static NetworkProtocolProvider resolveProtocol(NetProtocol proto)
110+
{
111+
return resolveProtocol(proto, null);
112+
}
113+
114+
/**
115+
* Install the requested {@link NetworkProtocolProvider}, overwriting the previously installed. Checks are done to
116+
* ensure that the provided protocol is valid
117+
* @param protocol The {@link NetProtocol}
118+
* @param provider The {@link NetworkProtocolProvider} to be installed.
119+
* @throws IllegalArgumentException if the provider is not usable for this protocol
120+
*/
121+
public static void install(NetProtocol protocol, NetworkProtocolProvider provider)
122+
{
123+
if (provider.isValid() && provider.handleProtocol(protocol)) {
124+
providers.put(protocol, provider);
125+
}
126+
else {
127+
throw new IllegalArgumentException("The given provider can't handle " + protocol);
128+
}
82129
}
83130

84131
public final boolean subscribe2all;
@@ -94,22 +141,9 @@ public void resolve(Address paddr, boolean ipv6)
94141

95142
public boolean isValid()
96143
{
97-
return providers.containsKey(this) && providers.get(this).isValid();
98-
}
99-
100-
/**
101-
* @param protocol name
102-
* @throws IllegalArgumentException if the protocol name can be matched to an actual supported protocol
103-
* @return
104-
*/
105-
public static NetProtocol getProtocol(String protocol)
106-
{
107-
try {
108-
return valueOf(protocol.toLowerCase(Locale.ENGLISH));
109-
}
110-
catch (NullPointerException | IllegalArgumentException e) {
111-
throw new IllegalArgumentException("Unknown protocol: \"" + protocol + "\"");
112-
}
144+
return Optional.ofNullable(providers.computeIfAbsent(this, NetProtocol::resolveProtocol))
145+
.map(NetworkProtocolProvider::isValid)
146+
.orElse(false);
113147
}
114148

115149
public final boolean compatible(int type)
@@ -119,7 +153,7 @@ public final boolean compatible(int type)
119153

120154
public Listener getListener(IOThread ioThread, SocketBase socket, Options options)
121155
{
122-
return providers.get(this).getListener(ioThread, socket, options);
156+
return resolve().getListener(ioThread, socket, options);
123157
}
124158

125159
public void resolve(Address paddr, boolean ipv6)
@@ -129,12 +163,21 @@ public void resolve(Address paddr, boolean ipv6)
129163

130164
public IZAddress zresolve(String addr, boolean ipv6)
131165
{
132-
return providers.get(this).zresolve(addr, ipv6);
166+
return resolve().zresolve(addr, ipv6);
133167
}
134168

135169
public void startConnecting(Options options, IOThread ioThread, SessionBase session, Address addr,
136170
boolean delayedStart, Consumer<Own> launchChild, BiConsumer<SessionBase, IEngine> sendAttach)
137171
{
138-
providers.get(this).startConnecting(options, ioThread, session, addr, delayedStart, launchChild, sendAttach);
172+
resolve().startConnecting(options, ioThread, session, addr, delayedStart, launchChild, sendAttach);
173+
}
174+
175+
private NetworkProtocolProvider resolve()
176+
{
177+
NetworkProtocolProvider protocolProvider = providers.computeIfAbsent(this, NetProtocol::resolveProtocol);
178+
if (protocolProvider == null || ! protocolProvider.isValid()) {
179+
throw new IllegalArgumentException("Unsupported network protocol " + this);
180+
}
181+
return protocolProvider;
139182
}
140183
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
zmq.io.net.inproc.InprocNetworkProtocolProvider
2+
zmq.io.net.ipc.IpcNetworkProtocolProvider
3+
zmq.io.net.norm.NormNetworkProtocolProvider
4+
zmq.io.net.pgm.EpgmNetworkProtocolProvider
5+
zmq.io.net.pgm.PgmNetworkProtocolProvider
6+
zmq.io.net.tcp.TcpNetworkProtocolProvider
7+
zmq.io.net.tipc.TipcNetworkProtocolProvider

src/test/java/zmq/io/AbstractProtocolVersion.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
package zmq.io;
22

3-
import static org.hamcrest.CoreMatchers.is;
4-
import static org.hamcrest.CoreMatchers.notNullValue;
5-
import static org.hamcrest.MatcherAssert.assertThat;
6-
73
import java.io.IOException;
84
import java.io.InputStream;
95
import java.io.OutputStream;
@@ -12,6 +8,7 @@
128
import java.util.ArrayList;
139
import java.util.Arrays;
1410
import java.util.List;
11+
import java.util.concurrent.atomic.AtomicReference;
1512

1613
import zmq.Ctx;
1714
import zmq.Msg;
@@ -21,9 +18,15 @@
2118
import zmq.ZMQ.Event;
2219
import zmq.util.TestUtils;
2320

21+
import static org.hamcrest.CoreMatchers.is;
22+
import static org.hamcrest.CoreMatchers.notNullValue;
23+
import static org.hamcrest.CoreMatchers.nullValue;
24+
import static org.hamcrest.MatcherAssert.assertThat;
25+
2426
public abstract class AbstractProtocolVersion
2527
{
2628
protected static final int REPETITIONS = 1000;
29+
private static final AtomicReference<Throwable> monitorFailure = new AtomicReference<>();
2730

2831
static class SocketMonitor extends Thread
2932
{
@@ -35,6 +38,11 @@ public SocketMonitor(Ctx ctx, String monitorAddr)
3538
{
3639
this.ctx = ctx;
3740
this.monitorAddr = monitorAddr;
41+
monitorFailure.set(null);
42+
this.setUncaughtExceptionHandler((t, ex) -> {
43+
ex.printStackTrace();
44+
monitorFailure.set(ex);
45+
});
3846
}
3947

4048
@Override
@@ -87,6 +95,7 @@ protected byte[] assertProtocolVersion(int version, List<ByteBuffer> raws, Strin
8795
for (ByteBuffer raw : raws) {
8896
out.write(raw.array());
8997
}
98+
assertThat(monitorFailure.get(), nullValue());
9099

91100
Msg msg = ZMQ.recv(receiver, 0);
92101
assertThat(msg, notNullValue());

0 commit comments

Comments
 (0)