Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/main/java/net/schmizz/sshj/transport/Transport.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import java.io.InputStream;
import java.io.OutputStream;
import java.util.List;
import java.util.concurrent.TimeUnit;

/** Transport layer of the SSH protocol. */
Expand Down Expand Up @@ -224,5 +225,5 @@ long write(SSHPacket payload)
void die(Exception e);

KeyAlgorithm getHostKeyAlgorithm();
KeyAlgorithm getClientKeyAlgorithm(KeyType keyType) throws TransportException;
List<KeyAlgorithm> getClientKeyAlgorithms(KeyType keyType) throws TransportException;
}
12 changes: 8 additions & 4 deletions src/main/java/net/schmizz/sshj/transport/TransportImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;
Expand Down Expand Up @@ -637,15 +638,18 @@ public KeyAlgorithm getHostKeyAlgorithm() {
}

@Override
public KeyAlgorithm getClientKeyAlgorithm(KeyType keyType) throws TransportException {
public List<KeyAlgorithm> getClientKeyAlgorithms(KeyType keyType) throws TransportException {
List<Factory.Named<KeyAlgorithm>> factories = getConfig().getKeyAlgorithms();
List<KeyAlgorithm> available = new ArrayList<>();
if (factories != null)
for (Factory.Named<KeyAlgorithm> f : factories)
if (
f instanceof KeyAlgorithms.Factory && ((KeyAlgorithms.Factory) f).getKeyType().equals(keyType)
|| !(f instanceof KeyAlgorithms.Factory) && f.getName().equals(keyType.toString())
|| !(f instanceof KeyAlgorithms.Factory) && f.getName().equals(keyType.toString())
)
return f.create();
throw new TransportException("Cannot find an available KeyAlgorithm for type " + keyType);
available.add(f.create());
if (available.isEmpty())
throw new TransportException("Cannot find an available KeyAlgorithm for type " + keyType);
return available;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,36 @@
import java.io.IOException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.util.LinkedList;
import java.util.Queue;

public abstract class KeyedAuthMethod
extends AbstractAuthMethod {

protected final KeyProvider kProv;
private Queue<KeyAlgorithm> available;

public KeyedAuthMethod(String name, KeyProvider kProv) {
super(name);
this.kProv = kProv;
}

private KeyAlgorithm getPublicKeyAlgorithm(KeyType keyType) throws TransportException {
if (available == null) {
available = new LinkedList<>(params.getTransport().getClientKeyAlgorithms(keyType));
}
return available.peek();
}

@Override
public boolean shouldRetry() {
if (available != null) {
available.poll();
return !available.isEmpty();
}
return false;
}

protected SSHPacket putPubKey(SSHPacket reqBuf)
throws UserAuthException {
PublicKey key;
Expand All @@ -50,7 +69,7 @@ protected SSHPacket putPubKey(SSHPacket reqBuf)
// public key as 2 strings: [ key type | key blob ]
KeyType keyType = KeyType.fromKey(key);
try {
KeyAlgorithm ka = params.getTransport().getClientKeyAlgorithm(keyType);
KeyAlgorithm ka = getPublicKeyAlgorithm(keyType);
if (ka != null) {
reqBuf.putString(ka.getKeyAlgorithm())
.putString(new Buffer.PlainBuffer().putPublicKey(key).getCompactData());
Expand All @@ -74,7 +93,7 @@ protected SSHPacket putSig(SSHPacket reqBuf)
final KeyType kt = KeyType.fromKey(key);
Signature signature;
try {
signature = params.getTransport().getClientKeyAlgorithm(kt).newSignature();
signature = getPublicKeyAlgorithm(kt).newSignature();
} catch (TransportException e) {
throw new UserAuthException("No KeyAlgorithm configured for key " + kt);
}
Expand Down