Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 47 additions & 20 deletions src/main/java/net/juniper/netconf/NetconfSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,14 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;

import static java.util.Optional.ofNullable;

/**
* A {@code NetconfSession} is obtained by first building a
Expand All @@ -59,6 +66,7 @@
public class NetconfSession {

private static final org.slf4j.Logger log = org.slf4j.LoggerFactory.getLogger(NetconfSession.class);
private final ExecutorService singleThreadExecutor = Executors.newSingleThreadExecutor();

private final Channel netconfChannel;
private String serverCapability;
Expand Down Expand Up @@ -127,30 +135,49 @@ private void sendHello(String hello) throws IOException {
}

@VisibleForTesting
String getRpcReply(String rpc) throws IOException {
String getRpcReply(final String rpc) throws IOException {
// write the rpc to the device
sendRpcRequest(rpc);

final char[] buffer = new char[BUFFER_SIZE];
final StringBuilder rpcReply = new StringBuilder();
final long startTime = System.nanoTime();
final Reader in = new InputStreamReader(stdInStreamFromDevice, Charsets.UTF_8);
boolean timeoutNotExceeded = true;
int promptPosition;
while ((promptPosition = rpcReply.indexOf(NetconfConstants.DEVICE_PROMPT)) < 0 &&
(timeoutNotExceeded = (TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime) < commandTimeout))) {
int charsRead = in.read(buffer, 0, buffer.length);
if (charsRead < 0) throw new NetconfException("Input Stream has been closed during reading.");
rpcReply.append(buffer, 0, charsRead);
}

if (!timeoutNotExceeded)
final AtomicReference<Thread> threadReference = new AtomicReference<>();
try {
return singleThreadExecutor.submit(() -> {
try {

threadReference.set(Thread.currentThread());
final char[] buffer = new char[BUFFER_SIZE];
final StringBuilder rpcReply = new StringBuilder();
final Reader in = new InputStreamReader(stdInStreamFromDevice, Charsets.UTF_8);
int promptPosition;
while ((promptPosition = rpcReply.indexOf(NetconfConstants.DEVICE_PROMPT)) < 0) {
int charsRead = in.read(buffer, 0, buffer.length);
if (charsRead < 0) throw new NetconfException("Input Stream has been closed during reading.");
rpcReply.append(buffer, 0, charsRead);
}

log.debug("Received Netconf RPC-Reply\n{}", rpcReply);
rpcReply.setLength(promptPosition);
return rpcReply.toString();

} catch (final Exception e) {
log.warn("Error reading from input stream", e);
throw e;
}
})
.get(commandTimeout, TimeUnit.MILLISECONDS);
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
throw new NetconfException("Thread interrupted whilst waiting for RPC reply", e);
} catch (final ExecutionException e) {
if(e.getCause() instanceof NetconfException) {
throw (NetconfException) e.getCause();
}
throw new NetconfException("Unexpected exception whilst waiting for RPC reply", e);
} catch (final TimeoutException e) {
// Make sure the thread isn't still running
ofNullable(threadReference.get()).ifPresent(Thread::interrupt);
throw new SocketTimeoutException("Command timeout limit was exceeded: " + commandTimeout);
// fixing the rpc reply by removing device prompt
log.debug("Received Netconf RPC-Reply\n{}", rpcReply);
rpcReply.setLength(promptPosition);

return rpcReply.toString();
}
}

private BufferedReader getRpcReplyRunning(String rpc) throws IOException {
Expand Down
72 changes: 68 additions & 4 deletions src/test/java/net/juniper/netconf/NetconfSessionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.slf4j.Logger;
Expand All @@ -25,17 +26,17 @@
import java.net.SocketTimeoutException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doCallRealMethod;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class NetconfSessionTest {
Expand Down Expand Up @@ -489,4 +490,67 @@ private String createHelloMessage() {
+ " <session-id>27700</session-id>\n"
+ "</hello>";
}


@Test
@Timeout(value = 2, unit = TimeUnit.SECONDS)
void ifTheDeviceDoesNotRespondAnExceptionWillBeThrown() {
final Duration commandTimeoutDuration = Duration.ofSeconds(1);

final Instant startTime = Instant.now();
assertThatThrownBy(() -> createNetconfSession((int) commandTimeoutDuration.toMillis()))
.isInstanceOf(SocketTimeoutException.class)
.hasMessageStartingWith("Command timeout limit was exceeded");

final Duration executeRpcDuration = Duration.between(startTime, Instant.now());
// This should have taken about 1 second to time out
assertThat(executeRpcDuration)
.isGreaterThanOrEqualTo(commandTimeoutDuration);
}

@Test
@Timeout(value = 2, unit = TimeUnit.SECONDS)
void ifTheDeviceDoesNotRespondTheSessionCanStillBeUsed() throws Exception {

final Semaphore semaphore = new Semaphore(0);

final Duration commandTimeoutDuration = Duration.ofSeconds(1);

new Thread(() -> {
try {
// This is the "hello" from the device, in response to the "Hello" to the initial client ""hello"
outPipe.write(FAKE_RPC_REPLY.getBytes(StandardCharsets.UTF_8));
outPipe.write(DEVICE_PROMPT_BYTE);
outPipe.flush();

// Don't send any response until it's required
semaphore.acquire();
// Now send a second response
outPipe.write(FAKE_RPC_REPLY.getBytes(StandardCharsets.UTF_8));
outPipe.write(DEVICE_PROMPT_BYTE);
outPipe.flush();
outPipe.close();
} catch (final Exception e) {
log.error("Error in background thread", e);
}
}).start();
final NetconfSession netconfSession = createNetconfSession((int) commandTimeoutDuration.toMillis());
// We've now received a "FAKE_RPC_REPLY"

// Now send a request, but we're expecting a timeout as the device won't send it yet
final Instant startTime = Instant.now();
assertThatThrownBy(() -> netconfSession.getRpcReply("<some-command/>"))
.isInstanceOf(SocketTimeoutException.class)
.hasMessageStartingWith("Command timeout limit was exceeded");
final Duration executeRpcDuration = Duration.between(startTime, Instant.now());

// This should have taken about 1 second to time out
assertThat(executeRpcDuration)
.isGreaterThanOrEqualTo(commandTimeoutDuration);

// Try again - we should get a reply
semaphore.release(); // Ensure the device sends a response
final String rpcReply = netconfSession.getRpcReply("<some-command/>");
assertThat(rpcReply).isEqualTo(FAKE_RPC_REPLY);
}
}