Skip to content

Commit

Permalink
Implemented diffie-hellman-group-exchange Kex methods (Fixes #167)
Browse files Browse the repository at this point in the history
  • Loading branch information
hierynomus committed Oct 29, 2015
1 parent e24ed6e commit 47df71c
Show file tree
Hide file tree
Showing 10 changed files with 300 additions and 20 deletions.
6 changes: 4 additions & 2 deletions src/main/java/net/schmizz/sshj/DefaultConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import net.schmizz.sshj.transport.compression.NoneCompression;
import net.schmizz.sshj.transport.kex.DHG1;
import net.schmizz.sshj.transport.kex.DHG14;
import net.schmizz.sshj.transport.kex.DHGexSHA1;
import net.schmizz.sshj.transport.kex.DHGexSHA256;
import net.schmizz.sshj.transport.mac.HMACMD5;
import net.schmizz.sshj.transport.mac.HMACMD596;
import net.schmizz.sshj.transport.mac.HMACSHA1;
Expand Down Expand Up @@ -98,9 +100,9 @@ public DefaultConfig() {

protected void initKeyExchangeFactories(boolean bouncyCastleRegistered) {
if (bouncyCastleRegistered)
setKeyExchangeFactories(new DHG14.Factory(), new DHG1.Factory());
setKeyExchangeFactories(new DHG14.Factory(), new DHG1.Factory(), new DHGexSHA1.Factory(), new DHGexSHA256.Factory());
else
setKeyExchangeFactories(new DHG1.Factory());
setKeyExchangeFactories(new DHG1.Factory(), new DHGexSHA1.Factory());
}

protected void initRandomFactory(boolean bouncyCastleRegistered) {
Expand Down
26 changes: 26 additions & 0 deletions src/main/java/net/schmizz/sshj/transport/digest/SHA256.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package net.schmizz.sshj.transport.digest;

/** SHA256 Digest. */
public class SHA256 extends BaseDigest {

/** Named factory for SHA256 digest */
public static class Factory
implements net.schmizz.sshj.common.Factory.Named<Digest> {

@Override
public Digest create() {
return new SHA256();
}

@Override
public String getName() {
return "sha256";
}
}

/** Create a new instance of a SHA256 digest */
public SHA256() {
super("SHA-256", 32);
}

}
21 changes: 3 additions & 18 deletions src/main/java/net/schmizz/sshj/transport/kex/AbstractDHG.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,14 @@
* Base class for DHG key exchange algorithms. Implementations will only have to configure the required data on the
* {@link DH} class in the
*/
public abstract class AbstractDHG
public abstract class AbstractDHG extends KeyExchangeBase
implements KeyExchange {

private final Logger log = LoggerFactory.getLogger(getClass());

private Transport trans;

private final Digest sha1 = new SHA1();
private final DH dh = new DH();

private String V_S;
private String V_C;
private byte[] I_S;
private byte[] I_C;

private byte[] H;
private PublicKey hostKey;

Expand All @@ -79,11 +72,7 @@ public PublicKey getHostKey() {
@Override
public void init(Transport trans, String V_S, String V_C, byte[] I_S, byte[] I_C)
throws GeneralSecurityException, TransportException {
this.trans = trans;
this.V_S = V_S;
this.V_C = V_C;
this.I_S = Arrays.copyOf(I_S, I_S.length);
this.I_C = Arrays.copyOf(I_C, I_C.length);
super.init(trans, V_S, V_C, I_S, I_C);
sha1.init();
initDH(dh);

Expand Down Expand Up @@ -112,11 +101,7 @@ public boolean next(Message msg, SSHPacket packet)

dh.computeK(f);

final Buffer.PlainBuffer buf = new Buffer.PlainBuffer()
.putString(V_C)
.putString(V_S)
.putString(I_C)
.putString(I_S)
final Buffer.PlainBuffer buf = initializedBuffer()
.putString(K_S)
.putMPInt(dh.getE())
.putMPInt(f)
Expand Down
124 changes: 124 additions & 0 deletions src/main/java/net/schmizz/sshj/transport/kex/AbstractDHGex.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package net.schmizz.sshj.transport.kex;

import net.schmizz.sshj.common.*;
import net.schmizz.sshj.signature.Signature;
import net.schmizz.sshj.transport.Transport;
import net.schmizz.sshj.transport.TransportException;
import net.schmizz.sshj.transport.digest.Digest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.math.BigInteger;
import java.security.GeneralSecurityException;
import java.security.PublicKey;
import java.util.Arrays;

public abstract class AbstractDHGex extends KeyExchangeBase {
private final Logger log = LoggerFactory.getLogger(getClass());

private Digest digest;

private int minBits = 1024;
private int maxBits = 8192;
private int preferredBits = 2048;

private DH dh;
private PublicKey hostKey;
private byte[] H;

public AbstractDHGex(Digest digest) {
this.digest = digest;
}

@Override
public void init(Transport trans, String V_S, String V_C, byte[] I_S, byte[] I_C) throws GeneralSecurityException, TransportException {
super.init(trans, V_S, V_C, I_S, I_C);
dh = new DH();
digest.init();

log.debug("Sending {}", Message.KEX_DH_GEX_REQUEST);
trans.write(new SSHPacket(Message.KEX_DH_GEX_REQUEST).putUInt32(minBits).putUInt32(preferredBits).putUInt32(maxBits));
}

@Override
public byte[] getH() {
return Arrays.copyOf(H, H.length);
}

@Override
public BigInteger getK() {
return dh.getK();
}

@Override
public Digest getHash() {
return digest;
}

@Override
public PublicKey getHostKey() {
return hostKey;
}

@Override
public boolean next(Message msg, SSHPacket buffer) throws GeneralSecurityException, TransportException {
log.debug("Got message {}", msg);
try {
switch (msg) {
case KEXDH_31:
return parseGexGroup(buffer);
case KEX_DH_GEX_REPLY:
return parseGexReply(buffer);
}
} catch (Buffer.BufferException be) {
throw new TransportException(be);
}
throw new TransportException("Unexpected message " + msg);
}

private boolean parseGexReply(SSHPacket buffer) throws Buffer.BufferException, GeneralSecurityException, TransportException {
byte[] K_S = buffer.readBytes();
BigInteger f = buffer.readMPInt();
byte[] sig = buffer.readBytes();
hostKey = new Buffer.PlainBuffer(K_S).readPublicKey();

dh.computeK(f);
BigInteger k = dh.getK();

final Buffer.PlainBuffer buf = initializedBuffer()
.putString(K_S)
.putUInt32(minBits)
.putUInt32(preferredBits)
.putUInt32(maxBits)
.putMPInt(dh.getP())
.putMPInt(dh.getG())
.putMPInt(dh.getE())
.putMPInt(f)
.putMPInt(k);
digest.update(buf.array(), buf.rpos(), buf.available());
H = digest.digest();
Signature signature = Factory.Named.Util.create(trans.getConfig().getSignatureFactories(),
KeyType.fromKey(hostKey).toString());
signature.init(hostKey, null);
signature.update(H, 0, H.length);
if (!signature.verify(sig))
throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED,
"KeyExchange signature verification failed");
return true;

}

private boolean parseGexGroup(SSHPacket buffer) throws Buffer.BufferException, GeneralSecurityException, TransportException {
BigInteger p = buffer.readMPInt();
BigInteger g = buffer.readMPInt();
int bitLength = p.bitLength();
if (bitLength < minBits || bitLength > maxBits) {
throw new GeneralSecurityException("Server generated gex p is out of range (" + bitLength + " bits)");
}
log.debug("Received server p bitlength {}", bitLength);
dh.init(p, g);
log.debug("Sending {}", Message.KEX_DH_GEX_INIT);
trans.write(new SSHPacket(Message.KEX_DH_GEX_INIT).putMPInt(dh.getE()));
return false;
}
}
7 changes: 7 additions & 0 deletions src/main/java/net/schmizz/sshj/transport/kex/DH.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,11 @@ public BigInteger getK() {
return K;
}

public BigInteger getP() {
return p;
}

public BigInteger getG() {
return g;
}
}
25 changes: 25 additions & 0 deletions src/main/java/net/schmizz/sshj/transport/kex/DHGexSHA1.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package net.schmizz.sshj.transport.kex;

import net.schmizz.sshj.transport.digest.SHA1;

public class DHGexSHA1 extends AbstractDHGex {

/** Named factory for DHGexSHA1 key exchange */
public static class Factory
implements net.schmizz.sshj.common.Factory.Named<KeyExchange> {

@Override
public KeyExchange create() {
return new DHGexSHA1();
}

@Override
public String getName() {
return "diffie-hellman-group-exchange-sha1";
}
}

public DHGexSHA1() {
super(new SHA1());
}
}
25 changes: 25 additions & 0 deletions src/main/java/net/schmizz/sshj/transport/kex/DHGexSHA256.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package net.schmizz.sshj.transport.kex;

import net.schmizz.sshj.transport.digest.SHA256;

public class DHGexSHA256 extends AbstractDHGex {

/** Named factory for DHGexSHA256 key exchange */
public static class Factory
implements net.schmizz.sshj.common.Factory.Named<KeyExchange> {

@Override
public KeyExchange create() {
return new DHGexSHA256();
}

@Override
public String getName() {
return "diffie-hellman-group-exchange-sha256";
}
}

public DHGexSHA256() {
super(new SHA256());
}
}
37 changes: 37 additions & 0 deletions src/main/java/net/schmizz/sshj/transport/kex/KeyExchangeBase.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package net.schmizz.sshj.transport.kex;

import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.transport.Transport;
import net.schmizz.sshj.transport.TransportException;

import java.security.GeneralSecurityException;
import java.util.Arrays;

/**
* Created by ajvanerp on 29/10/15.
*/
public abstract class KeyExchangeBase implements KeyExchange {
protected Transport trans;

private String V_S;
private String V_C;
private byte[] I_S;
private byte[] I_C;

@Override
public void init(Transport trans, String V_S, String V_C, byte[] I_S, byte[] I_C) throws GeneralSecurityException, TransportException {
this.trans = trans;
this.V_S = V_S;
this.V_C = V_C;
this.I_S = Arrays.copyOf(I_S, I_S.length);
this.I_C = Arrays.copyOf(I_C, I_C.length);
}

protected Buffer.PlainBuffer initializedBuffer() {
return new Buffer.PlainBuffer()
.putString(V_C)
.putString(V_S)
.putString(I_C)
.putString(I_S);
}
}
4 changes: 4 additions & 0 deletions src/test/java/com/hierynomus/sshj/test/SshFixture.java
Original file line number Diff line number Diff line change
Expand Up @@ -158,4 +158,8 @@ public void stopServer() {
}
}
}

public SshServer getServer() {
return server;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package com.hierynomus.sshj.transport.kex;

import com.hierynomus.sshj.test.SshFixture;
import net.schmizz.sshj.SSHClient;
import org.apache.sshd.common.KeyExchange;
import org.apache.sshd.common.NamedFactory;
import org.apache.sshd.server.kex.DHGEX;
import org.apache.sshd.server.kex.DHGEX256;
import org.junit.After;
import org.junit.Rule;
import org.junit.Test;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;

import static org.hamcrest.MatcherAssert.assertThat;

public class DiffieHellmanGroupExchangeTest {
@Rule
public SshFixture fixture = new SshFixture(false);

@After
public void stopServer() {
fixture.stopServer();
}

@Test
public void shouldKexWithGroupExchangeSha1() throws IOException {
setupAndCheckKex(new DHGEX.Factory());
}

@Test
public void shouldKexWithGroupExchangeSha256() throws IOException {
setupAndCheckKex(new DHGEX256.Factory());
}

private void setupAndCheckKex(NamedFactory<KeyExchange> factory) throws IOException {
fixture.getServer().setKeyExchangeFactories(Collections.singletonList(factory));
fixture.start();
SSHClient sshClient = fixture.setupConnectedDefaultClient();
assertThat("should be connected", sshClient.isConnected());
sshClient.disconnect();
}
}

0 comments on commit 47df71c

Please sign in to comment.