Skip to content

INT-4366: Fix MulticastSendingMessageHandler #2329

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 19, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2001-2016 the original author or authors.
* Copyright 2001-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -38,6 +38,8 @@
* determine success.
*
* @author Gary Russell
* @author Artem Bilan
*
* @since 2.0
*/
public class MulticastSendingMessageHandler extends UnicastSendingMessageHandler {
Expand Down Expand Up @@ -126,49 +128,45 @@ public MulticastSendingMessageHandler(String destinationExpression) {

@Override
protected DatagramSocket getSocket() throws IOException {
if (this.getTheSocket() == null) {
if (this.multicastSocket == null) {
synchronized (this) {
createSocket();
if (this.multicastSocket == null) {
createSocket();
}
}
}
return this.getTheSocket();
return getTheSocket();
}

private void createSocket() throws IOException {
if (this.getTheSocket() == null) {
MulticastSocket socket;
if (this.isAcknowledge()) {
int ackPort = this.getAckPort();
if (this.localAddress == null) {
socket = ackPort == 0 ? new MulticastSocket() : new MulticastSocket(ackPort);
}
else {
InetAddress whichNic = InetAddress.getByName(this.localAddress);
socket = new MulticastSocket(new InetSocketAddress(whichNic, ackPort));
}
if (getSoReceiveBufferSize() > 0) {
socket.setReceiveBufferSize(this.getSoReceiveBufferSize());
}
if (logger.isDebugEnabled()) {
logger.debug("Listening for acks on port: " + socket.getLocalPort());
}
setSocket(socket);
updateAckAddress();
MulticastSocket socket;
if (isAcknowledge()) {
int ackPort = getAckPort();
if (this.localAddress == null) {
socket = ackPort == 0 ? new MulticastSocket() : new MulticastSocket(ackPort);
}
else {
socket = new MulticastSocket();
setSocket(socket);
InetAddress whichNic = InetAddress.getByName(this.localAddress);
socket = new MulticastSocket(new InetSocketAddress(whichNic, ackPort));
}
if (this.timeToLive >= 0) {
socket.setTimeToLive(this.timeToLive);
if (getSoReceiveBufferSize() > 0) {
socket.setReceiveBufferSize(getSoReceiveBufferSize());
}
setSocketAttributes(socket);
if (this.localAddress != null) {
InetAddress whichNic = InetAddress.getByName(this.localAddress);
socket.setInterface(whichNic);
if (logger.isDebugEnabled()) {
logger.debug("Listening for acks on port: " + socket.getLocalPort());
}
this.multicastSocket = socket;
setSocket(socket);
updateAckAddress();
}
else {
socket = new MulticastSocket();
setSocket(socket);
}
if (this.timeToLive >= 0) {
socket.setTimeToLive(this.timeToLive);
}
setSocketAttributes(socket);
this.multicastSocket = socket;
}


Expand All @@ -178,7 +176,7 @@ private void createSocket() throws IOException {
* @param minAcksForSuccess The minimum number of acks that will represent success.
*/
public void setMinAcksForSuccess(int minAcksForSuccess) {
this.setAckCounter(minAcksForSuccess);
setAckCounter(minAcksForSuccess);
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2016 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,7 +28,6 @@
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
Expand All @@ -39,6 +38,7 @@
import org.junit.Test;

import org.springframework.beans.factory.BeanFactory;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.integration.channel.DirectChannel;
import org.springframework.integration.channel.QueueChannel;
import org.springframework.integration.handler.ServiceActivatingHandler;
Expand All @@ -56,6 +56,8 @@

/**
* @author Gary Russell
* @author Artem Bilan
*
* @since 2.0
*/
public class TcpInboundGatewayTests {
Expand Down Expand Up @@ -119,30 +121,31 @@ public void testNetClientMode() throws Exception {
final CountDownLatch latch2 = new CountDownLatch(1);
final CountDownLatch latch3 = new CountDownLatch(1);
final AtomicBoolean done = new AtomicBoolean();
Executors.newSingleThreadExecutor().execute(() -> {
try {
ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0, 10);
port.set(server.getLocalPort());
latch1.countDown();
Socket socket = server.accept();
socket.getOutputStream().write("Test1\r\nTest2\r\n".getBytes());
byte[] bytes = new byte[12];
readFully(socket.getInputStream(), bytes);
assertEquals("Echo:Test1\r\n", new String(bytes));
readFully(socket.getInputStream(), bytes);
assertEquals("Echo:Test2\r\n", new String(bytes));
latch2.await();
socket.close();
server.close();
done.set(true);
latch3.countDown();
}
catch (Exception e) {
if (!done.get()) {
e.printStackTrace();
}
}
});
new SimpleAsyncTaskExecutor()
.execute(() -> {
try {
ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0, 10);
port.set(server.getLocalPort());
latch1.countDown();
Socket socket = server.accept();
socket.getOutputStream().write("Test1\r\nTest2\r\n".getBytes());
byte[] bytes = new byte[12];
readFully(socket.getInputStream(), bytes);
assertEquals("Echo:Test1\r\n", new String(bytes));
readFully(socket.getInputStream(), bytes);
assertEquals("Echo:Test2\r\n", new String(bytes));
latch2.await();
socket.close();
server.close();
done.set(true);
latch3.countDown();
}
catch (Exception e) {
if (!done.get()) {
e.printStackTrace();
}
}
});
assertTrue(latch1.await(10, TimeUnit.SECONDS));
AbstractClientConnectionFactory ccf = new TcpNetClientConnectionFactory("localhost", port.get());
ccf.setSingleUse(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
Expand All @@ -61,6 +60,8 @@
import org.springframework.beans.factory.BeanFactory;
import org.springframework.core.serializer.DefaultDeserializer;
import org.springframework.core.serializer.DefaultSerializer;
import org.springframework.core.task.AsyncTaskExecutor;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.Expression;
import org.springframework.expression.spel.standard.SpelExpressionParser;
Expand Down Expand Up @@ -90,6 +91,8 @@ public class TcpOutboundGatewayTests {

private static final Log logger = LogFactory.getLog(TcpOutboundGatewayTests.class);

private AsyncTaskExecutor executor = new SimpleAsyncTaskExecutor();

@ClassRule
public static LongRunningIntegrationTest longTests = new LongRunningIntegrationTest();

Expand All @@ -101,13 +104,13 @@ public class TcpOutboundGatewayTests {
public void testGoodNetSingle() throws Exception {
final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean done = new AtomicBoolean();
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<ServerSocket>();
Executors.newSingleThreadExecutor().execute(() -> {
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<>();
this.executor.execute(() -> {
try {
ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0, 100);
serverSocket.set(server);
latch.countDown();
List<Socket> sockets = new ArrayList<Socket>();
List<Socket> sockets = new ArrayList<>();
int i = 0;
while (true) {
Socket socket = server.accept();
Expand Down Expand Up @@ -165,8 +168,8 @@ public void testGoodNetSingle() throws Exception {
public void testGoodNetMultiplex() throws Exception {
final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean done = new AtomicBoolean();
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<ServerSocket>();
Executors.newSingleThreadExecutor().execute(() -> {
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<>();
this.executor.execute(() -> {
try {
ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0, 10);
serverSocket.set(server);
Expand Down Expand Up @@ -220,8 +223,8 @@ public void testGoodNetMultiplex() throws Exception {
public void testGoodNetTimeout() throws Exception {
final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean done = new AtomicBoolean();
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<ServerSocket>();
Executors.newSingleThreadExecutor().execute(() -> {
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<>();
this.executor.execute(() -> {
try {
ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0);
serverSocket.set(server);
Expand Down Expand Up @@ -260,12 +263,12 @@ public void testGoodNetTimeout() throws Exception {
Future<Integer>[] results = (Future<Integer>[]) new Future<?>[2];
for (int i = 0; i < 2; i++) {
final int j = i;
results[j] = (Executors.newSingleThreadExecutor().submit(() -> {
results[j] = (this.executor.submit(() -> {
gateway.handleMessage(MessageBuilder.withPayload("Test" + j).build());
return 0;
}));
}
Set<String> replies = new HashSet<String>();
Set<String> replies = new HashSet<>();
int timeouts = 0;
for (int i = 0; i < 2; i++) {
try {
Expand Down Expand Up @@ -344,7 +347,7 @@ private void testGoodNetGWTimeoutGuts(final int port, AbstractClientConnectionFa
final AtomicReference<String> lastReceived = new AtomicReference<String>();
final CountDownLatch serverLatch = new CountDownLatch(2);

Executors.newSingleThreadExecutor().execute(() -> {
this.executor.execute(() -> {
try {
latch.countDown();
int i = 0;
Expand Down Expand Up @@ -398,7 +401,7 @@ private void testGoodNetGWTimeoutGuts(final int port, AbstractClientConnectionFa

for (int i = 0; i < 2; i++) {
final int j = i;
results[j] = (Executors.newSingleThreadExecutor().submit(() -> {
results[j] = (this.executor.submit(() -> {
gateway.handleMessage(MessageBuilder.withPayload("Test" + j).build());
return j;
}));
Expand Down Expand Up @@ -442,7 +445,7 @@ public void testCachingFailover() throws Exception {
final AtomicBoolean done = new AtomicBoolean();
final CountDownLatch serverLatch = new CountDownLatch(1);

Executors.newSingleThreadExecutor().execute(() -> {
this.executor.execute(() -> {
try {
ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0);
serverSocket.set(server);
Expand Down Expand Up @@ -517,12 +520,12 @@ public void testCachingFailover() throws Exception {

@Test
public void testFailoverCached() throws Exception {
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<ServerSocket>();
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<>();
final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean done = new AtomicBoolean();
final CountDownLatch serverLatch = new CountDownLatch(1);

Executors.newSingleThreadExecutor().execute(() -> {
this.executor.execute(() -> {
try {
ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0);
serverSocket.set(server);
Expand Down Expand Up @@ -667,11 +670,11 @@ private void testGWPropagatesSocketCloseGuts(final int port, AbstractClientConne
final ServerSocket server) throws Exception {
final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean done = new AtomicBoolean();
final AtomicReference<String> lastReceived = new AtomicReference<String>();
final AtomicReference<String> lastReceived = new AtomicReference<>();
final CountDownLatch serverLatch = new CountDownLatch(1);

Executors.newSingleThreadExecutor().execute(() -> {
List<Socket> sockets = new ArrayList<Socket>();
this.executor.execute(() -> {
List<Socket> sockets = new ArrayList<>();
try {
latch.countDown();
while (!done.get()) {
Expand Down Expand Up @@ -793,8 +796,8 @@ private void testGWPropagatesSocketTimeoutGuts(final int port, AbstractClientCon
final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean done = new AtomicBoolean();

Executors.newSingleThreadExecutor().execute(() -> {
List<Socket> sockets = new ArrayList<Socket>();
this.executor.execute(() -> {
List<Socket> sockets = new ArrayList<>();
try {
latch.countDown();
while (!done.get()) {
Expand Down
Loading