11package zmq .io .net ;
22
33import java .util .Arrays ;
4- import java .util .HashMap ;
54import java .util .Locale ;
65import java .util .Map ;
6+ import java .util .Optional ;
7+ import java .util .ServiceLoader ;
78import java .util .Set ;
9+ import java .util .concurrent .ConcurrentHashMap ;
810import java .util .function .BiConsumer ;
911import java .util .function .Consumer ;
1012import java .util .stream .Collectors ;
1618import zmq .io .IOThread ;
1719import zmq .io .SessionBase ;
1820import zmq .io .net .Address .IZAddress ;
19- import zmq .io .net .inproc .InprocNetworkProtocolProvider ;
20- import zmq .io .net .ipc .IpcNetworkProtocolProvider ;
21- import zmq .io .net .norm .NormNetworkProtocolProvider ;
22- import zmq .io .net .pgm .EpgmNetworkProtocolProvider ;
23- import zmq .io .net .pgm .PgmNetworkProtocolProvider ;
24- import zmq .io .net .tcp .TcpNetworkProtocolProvider ;
25- import zmq .io .net .tipc .TipcNetworkProtocolProvider ;
2621import zmq .socket .Sockets ;
2722
2823public enum NetProtocol
@@ -69,16 +64,68 @@ public void resolve(Address paddr, boolean ipv6)
6964 },
7065 vmci (true , true );
7166
72- private static final Map <NetProtocol , NetworkProtocolProvider > providers ;
73- static {
74- providers = new HashMap <>(NetProtocol .values ().length );
75- providers .put (NetProtocol .tcp , new TcpNetworkProtocolProvider ());
76- providers .put (NetProtocol .ipc , new IpcNetworkProtocolProvider ());
77- providers .put (NetProtocol .tipc , new TipcNetworkProtocolProvider ());
78- providers .put (NetProtocol .norm , new NormNetworkProtocolProvider ());
79- providers .put (NetProtocol .inproc , new InprocNetworkProtocolProvider ());
80- providers .put (NetProtocol .pgm , new PgmNetworkProtocolProvider ());
81- providers .put (NetProtocol .epgm , new EpgmNetworkProtocolProvider ());
67+ private static final Map <NetProtocol , NetworkProtocolProvider > providers = new ConcurrentHashMap <>();
68+
69+ /**
70+ * @param protocol name
71+ * @throws IllegalArgumentException if the protocol name can be matched to an actual supported protocol
72+ * @return {@link NetProtocol} resolved by name
73+ */
74+ public static NetProtocol getProtocol (String protocol )
75+ {
76+ try {
77+ return valueOf (protocol .toLowerCase (Locale .ENGLISH ));
78+ }
79+ catch (NullPointerException | IllegalArgumentException e ) {
80+ throw new IllegalArgumentException ("Unknown protocol: \" " + protocol + "\" " );
81+ }
82+ }
83+
84+ /**
85+ * <p>Load a {@link NetworkProtocolProvider} and ensure that if a class loader is given, it will be used. If multiple
86+ * instance are returned, the first one will be used.</p>
87+ *
88+ * <p>The purpose of this method is to be able to handle how a {@link NetworkProtocolProvider} service is resolver, by
89+ * tweaking CLASSPATH and class loader.</p>
90+ * @param proto The protocol to search
91+ * @param cl The class loader used to resolve the {@link NetworkProtocolProvider} or null if not required.
92+ */
93+ public static void loadWithClassLoader (NetProtocol proto , ClassLoader cl )
94+ {
95+ Optional .ofNullable (resolveProtocol (proto , cl ))
96+ .ifPresent (npp -> providers .put (proto , npp ));
97+ }
98+
99+ private static NetworkProtocolProvider resolveProtocol (NetProtocol proto , ClassLoader cl )
100+ {
101+ ServiceLoader <NetworkProtocolProvider > serviceLoader = ServiceLoader .load (NetworkProtocolProvider .class , cl );
102+ return serviceLoader .stream ()
103+ .map (ServiceLoader .Provider ::get )
104+ .filter (npp -> npp .isValid () && npp .handleProtocol (proto ) && (cl == null || npp .getClass ().getClassLoader () == cl ))
105+ .findFirst ()
106+ .orElse (null );
107+ }
108+
109+ private static NetworkProtocolProvider resolveProtocol (NetProtocol proto )
110+ {
111+ return resolveProtocol (proto , null );
112+ }
113+
114+ /**
115+ * Install the requested {@link NetworkProtocolProvider}, overwriting the previously installed. Checks are done to
116+ * ensure that the provided protocol is valid
117+ * @param protocol The {@link NetProtocol}
118+ * @param provider The {@link NetworkProtocolProvider} to be installed.
119+ * @throws IllegalArgumentException if the provider is not usable for this protocol
120+ */
121+ public static void install (NetProtocol protocol , NetworkProtocolProvider provider )
122+ {
123+ if (provider .isValid () && provider .handleProtocol (protocol )) {
124+ providers .put (protocol , provider );
125+ }
126+ else {
127+ throw new IllegalArgumentException ("The given provider can't handle " + protocol );
128+ }
82129 }
83130
84131 public final boolean subscribe2all ;
@@ -94,22 +141,9 @@ public void resolve(Address paddr, boolean ipv6)
94141
95142 public boolean isValid ()
96143 {
97- return providers .containsKey (this ) && providers .get (this ).isValid ();
98- }
99-
100- /**
101- * @param protocol name
102- * @throws IllegalArgumentException if the protocol name can be matched to an actual supported protocol
103- * @return
104- */
105- public static NetProtocol getProtocol (String protocol )
106- {
107- try {
108- return valueOf (protocol .toLowerCase (Locale .ENGLISH ));
109- }
110- catch (NullPointerException | IllegalArgumentException e ) {
111- throw new IllegalArgumentException ("Unknown protocol: \" " + protocol + "\" " );
112- }
144+ return Optional .ofNullable (providers .computeIfAbsent (this , NetProtocol ::resolveProtocol ))
145+ .map (NetworkProtocolProvider ::isValid )
146+ .orElse (false );
113147 }
114148
115149 public final boolean compatible (int type )
@@ -119,7 +153,7 @@ public final boolean compatible(int type)
119153
120154 public Listener getListener (IOThread ioThread , SocketBase socket , Options options )
121155 {
122- return providers . get ( this ).getListener (ioThread , socket , options );
156+ return resolve ( ).getListener (ioThread , socket , options );
123157 }
124158
125159 public void resolve (Address paddr , boolean ipv6 )
@@ -129,12 +163,21 @@ public void resolve(Address paddr, boolean ipv6)
129163
130164 public IZAddress zresolve (String addr , boolean ipv6 )
131165 {
132- return providers . get ( this ).zresolve (addr , ipv6 );
166+ return resolve ( ).zresolve (addr , ipv6 );
133167 }
134168
135169 public void startConnecting (Options options , IOThread ioThread , SessionBase session , Address addr ,
136170 boolean delayedStart , Consumer <Own > launchChild , BiConsumer <SessionBase , IEngine > sendAttach )
137171 {
138- providers .get (this ).startConnecting (options , ioThread , session , addr , delayedStart , launchChild , sendAttach );
172+ resolve ().startConnecting (options , ioThread , session , addr , delayedStart , launchChild , sendAttach );
173+ }
174+
175+ private NetworkProtocolProvider resolve ()
176+ {
177+ NetworkProtocolProvider protocolProvider = providers .computeIfAbsent (this , NetProtocol ::resolveProtocol );
178+ if (protocolProvider == null || ! protocolProvider .isValid ()) {
179+ throw new IllegalArgumentException ("Unsupported network protocol " + this );
180+ }
181+ return protocolProvider ;
139182 }
140183}
0 commit comments