diff --git a/src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java b/src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java new file mode 100644 index 00000000..2abe71a7 --- /dev/null +++ b/src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java @@ -0,0 +1,111 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.sshj.transport.kex; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import ch.qos.logback.classic.Logger; +import ch.qos.logback.classic.spi.ILoggingEvent; +import ch.qos.logback.core.read.ListAppender; +import com.hierynomus.sshj.SshdContainer; +import net.schmizz.sshj.SSHClient; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.LoggerFactory; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@Testcontainers +class StrictKeyExchangeTest { + + @Container + private static final SshdContainer sshd = new SshdContainer(); + + private final List watchedLoggers = new ArrayList<>(); + private final ListAppender logWatcher = new ListAppender<>(); + + @BeforeEach + void setUpLogWatcher() { + logWatcher.start(); + setUpLogger("net.schmizz.sshj.transport.Decoder"); + setUpLogger("net.schmizz.sshj.transport.Encoder"); + setUpLogger("net.schmizz.sshj.transport.KeyExchanger"); + } + + @AfterEach + void tearDown() { + watchedLoggers.forEach(Logger::detachAndStopAllAppenders); + } + + private void setUpLogger(String className) { + Logger logger = ((Logger) LoggerFactory.getLogger(className)); + logger.addAppender(logWatcher); + watchedLoggers.add(logger); + } + + @Test + void strictKeyExchange() throws Throwable { + try (SSHClient client = sshd.getConnectedClient()) { + client.authPublickey("sshj", "src/itest/resources/keyfiles/id_rsa_opensshv1"); + assertTrue(client.isAuthenticated()); + } + List keyExchangerLogs = getLogs("KeyExchanger"); + assertThat(keyExchangerLogs).containsSequence( + "Initiating key exchange", + "Sending SSH_MSG_KEXINIT", + "Received SSH_MSG_KEXINIT", + "Enabling strict key exchange extension" + ); + List decoderLogs = getLogs("Decoder").stream() + .map(log -> log.split(":")[0]) + .collect(Collectors.toList()); + assertThat(decoderLogs).containsExactly( + "Received packet #0", + "Received packet #1", + "Received packet #2", + "Received packet #0", + "Received packet #1", + "Received packet #2", + "Received packet #3" + ); + List encoderLogs = getLogs("Encoder").stream() + .map(log -> log.split(":")[0]) + .collect(Collectors.toList()); + assertThat(encoderLogs).containsExactly( + "Encoding packet #0", + "Encoding packet #1", + "Encoding packet #2", + "Encoding packet #0", + "Encoding packet #1", + "Encoding packet #2", + "Encoding packet #3" + ); + } + + private List getLogs(String className) { + return logWatcher.list.stream() + .filter(event -> event.getLoggerName().endsWith(className)) + .map(ILoggingEvent::getFormattedMessage) + .collect(Collectors.toList()); + } + +} diff --git a/src/main/java/net/schmizz/sshj/transport/Converter.java b/src/main/java/net/schmizz/sshj/transport/Converter.java index 9d532f96..6f3431c3 100644 --- a/src/main/java/net/schmizz/sshj/transport/Converter.java +++ b/src/main/java/net/schmizz/sshj/transport/Converter.java @@ -51,6 +51,14 @@ long getSequenceNumber() { return seq; } + void resetSequenceNumber() { + seq = -1; + } + + boolean isSequenceNumberAtMax() { + return seq == 0xffffffffL; + } + void setAlgorithms(Cipher cipher, MAC mac, Compression compression) { this.cipher = cipher; this.mac = mac; diff --git a/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java b/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java index 6705519f..8df8c221 100644 --- a/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java +++ b/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java @@ -60,6 +60,10 @@ private static enum Expected { private final AtomicBoolean kexOngoing = new AtomicBoolean(); + private final AtomicBoolean initialKex = new AtomicBoolean(true); + + private final AtomicBoolean strictKex = new AtomicBoolean(); + /** What we are expecting from the next packet */ private Expected expected = Expected.KEXINIT; @@ -123,6 +127,14 @@ boolean isKexOngoing() { return kexOngoing.get(); } + boolean isStrictKex() { + return strictKex.get(); + } + + boolean isInitialKex() { + return initialKex.get(); + } + /** * Starts key exchange by sending a {@code SSH_MSG_KEXINIT} packet. Key exchange needs to be done once mandatorily * after initializing the {@link Transport} for it to be usable and may be initiated at any later point e.g. if @@ -171,7 +183,7 @@ private void sendKexInit() throws TransportException { log.debug("Sending SSH_MSG_KEXINIT"); List knownHostAlgs = findKnownHostAlgs(transport.getRemoteHost(), transport.getRemotePort()); - clientProposal = new Proposal(transport.getConfig(), knownHostAlgs); + clientProposal = new Proposal(transport.getConfig(), knownHostAlgs, initialKex.get()); transport.write(clientProposal.getPacket()); kexInitSent.set(); } @@ -190,6 +202,9 @@ private void sendNewKeys() throws TransportException { log.debug("Sending SSH_MSG_NEWKEYS"); transport.write(new SSHPacket(Message.NEWKEYS)); + if (strictKex.get()) { + transport.getEncoder().resetSequenceNumber(); + } } /** @@ -222,6 +237,10 @@ private synchronized void verifyHost(PublicKey key) private void setKexDone() { kexOngoing.set(false); + initialKex.set(false); + if (strictKex.get()) { + transport.getDecoder().resetSequenceNumber(); + } kexInitSent.clear(); done.set(); } @@ -230,6 +249,7 @@ private void gotKexInit(SSHPacket buf) throws TransportException { buf.rpos(buf.rpos() - 1); final Proposal serverProposal = new Proposal(buf); + gotStrictKexInfo(serverProposal); negotiatedAlgs = clientProposal.negotiate(serverProposal); log.debug("Negotiated algorithms: {}", negotiatedAlgs); for(AlgorithmsVerifier v: algorithmVerifiers) { @@ -253,6 +273,18 @@ private void gotKexInit(SSHPacket buf) } } + private void gotStrictKexInfo(Proposal serverProposal) throws TransportException { + if (initialKex.get() && serverProposal.isStrictKeyExchangeSupportedByServer()) { + strictKex.set(true); + log.debug("Enabling strict key exchange extension"); + if (transport.getDecoder().getSequenceNumber() != 0) { + throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, + "SSH_MSG_KEXINIT was not first package during strict key exchange" + ); + } + } + } + /** * Private method used while putting new keys into use that will resize the key used to initialize the cipher to the * needed length. diff --git a/src/main/java/net/schmizz/sshj/transport/Proposal.java b/src/main/java/net/schmizz/sshj/transport/Proposal.java index 5f5f8a1f..3a4102dd 100644 --- a/src/main/java/net/schmizz/sshj/transport/Proposal.java +++ b/src/main/java/net/schmizz/sshj/transport/Proposal.java @@ -37,8 +37,11 @@ class Proposal { private final List s2cComp; private final SSHPacket packet; - public Proposal(Config config, List knownHostAlgs) { + public Proposal(Config config, List knownHostAlgs, boolean initialKex) { kex = Factory.Named.Util.getNames(config.getKeyExchangeFactories()); + if (initialKex) { + kex.add("kex-strict-c-v00@openssh.com"); + } sig = filterKnownHostKeyAlgorithms(Factory.Named.Util.getNames(config.getKeyAlgorithms()), knownHostAlgs); c2sCipher = s2cCipher = Factory.Named.Util.getNames(config.getCipherFactories()); c2sMAC = s2cMAC = Factory.Named.Util.getNames(config.getMACFactories()); @@ -91,6 +94,10 @@ public List getKeyExchangeAlgorithms() { return kex; } + public boolean isStrictKeyExchangeSupportedByServer() { + return kex.contains("kex-strict-s-v00@openssh.com"); + } + public List getHostKeyAlgorithms() { return sig; } diff --git a/src/main/java/net/schmizz/sshj/transport/TransportImpl.java b/src/main/java/net/schmizz/sshj/transport/TransportImpl.java index edff191c..53d8b3e2 100644 --- a/src/main/java/net/schmizz/sshj/transport/TransportImpl.java +++ b/src/main/java/net/schmizz/sshj/transport/TransportImpl.java @@ -436,7 +436,7 @@ public long write(SSHPacket payload) assert m != Message.KEXINIT; kexer.waitForDone(); } - } else if (encoder.getSequenceNumber() == 0) // We get here every 2**32th packet + } else if (encoder.isSequenceNumberAtMax()) // We get here every 2**32th packet kexer.startKex(true); final long seq = encoder.encode(payload); @@ -489,9 +489,20 @@ public void handle(Message msg, SSHPacket buf) log.trace("Received packet {}", msg); + if (kexer.isInitialKex()) { + if (decoder.isSequenceNumberAtMax()) { + throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, + "Sequence number of decoder is about to wrap during initial key exchange"); + } + if (kexer.isStrictKex() && !isKexerPacket(msg) && msg != Message.DISCONNECT) { + throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, + "Unexpected packet type during initial strict key exchange"); + } + } + if (msg.geq(50)) { // not a transport layer packet service.handle(msg, buf); - } else if (msg.in(20, 21) || msg.in(30, 49)) { // kex packet + } else if (isKexerPacket(msg)) { kexer.handle(msg, buf); } else { switch (msg) { @@ -523,6 +534,10 @@ public void handle(Message msg, SSHPacket buf) } } + private static boolean isKexerPacket(Message msg) { + return msg.in(20, 21) || msg.in(30, 49); + } + private void gotDebug(SSHPacket buf) throws TransportException { try {