diff --git a/client/llb/source.go b/client/llb/source.go index e4eb2a7c756e..0e74c02f6dba 100644 --- a/client/llb/source.go +++ b/client/llb/source.go @@ -11,6 +11,7 @@ import ( "github.com/docker/distribution/reference" "github.com/moby/buildkit/solver/pb" "github.com/moby/buildkit/util/apicaps" + "github.com/moby/buildkit/util/sshutil" digest "github.com/opencontainers/go-digest" "github.com/pkg/errors" ) @@ -200,6 +201,7 @@ type ImageInfo struct { func Git(remote, ref string, opts ...GitOption) State { url := "" isSSH := true + var sshHost string for _, prefix := range []string{ "http://", "https://", @@ -219,8 +221,12 @@ func Git(remote, ref string, opts ...GitOption) State { //sshUser = parts[0] remote = parts[1] } - // keep remote consistent with http(s) version - remote = strings.Replace(remote, ":", "/", 1) + parts = strings.SplitN(remote, ":", 2) + if len(parts) == 2 { + sshHost = parts[0] + // keep remote consistent with http(s) version + remote = parts[0] + "/" + parts[1] + } } id := remote @@ -257,12 +263,23 @@ func Git(remote, ref string, opts ...GitOption) State { addCap(&gi.Constraints, pb.CapSourceGitHTTPAuth) } } - if gi.KnownSSHHosts != "" { - attrs[pb.AttrKnownSSHHosts] = gi.KnownSSHHosts + if isSSH { + if gi.KnownSSHHosts != "" { + attrs[pb.AttrKnownSSHHosts] = gi.KnownSSHHosts + } else if sshHost != "" { + keyscan, err := sshutil.SSHKeyScan(sshHost) + if err == nil { + // best effort + attrs[pb.AttrKnownSSHHosts] = keyscan + } + } addCap(&gi.Constraints, pb.CapSourceGitKnownSSHHosts) - } - if gi.MountSSHSock != "" { - attrs[pb.AttrMountSSHSock] = gi.MountSSHSock + + if gi.MountSSHSock == "" { + attrs[pb.AttrMountSSHSock] = "default" + } else { + attrs[pb.AttrMountSSHSock] = gi.MountSSHSock + } addCap(&gi.Constraints, pb.CapSourceGitMountSSHSock) } diff --git a/util/sshutil/keyscan.go b/util/sshutil/keyscan.go new file mode 100644 index 000000000000..2c219d95b978 --- /dev/null +++ b/util/sshutil/keyscan.go @@ -0,0 +1,47 @@ +package sshutil + +import ( + "fmt" + "net" + "strings" + + "golang.org/x/crypto/ssh" +) + +const defaultPort = 22 + +var ErrMalformedServer = fmt.Errorf("invalid server, must be of the form hostname, or hostname:port") + +var errCallbackDone = fmt.Errorf("callback failed on purpose") + +// SshKeyScan scans a ssh server for the hostkey; server should be in the form hostname, or hostname:port +func SSHKeyScan(server string) (string, error) { + var key string + KeyScanCallback := func(hostname string, remote net.Addr, pubKey ssh.PublicKey) error { + key = strings.TrimSpace(fmt.Sprintf("%s %s", hostname[:len(hostname)-3], string(ssh.MarshalAuthorizedKey(pubKey)))) + return errCallbackDone + } + config := &ssh.ClientConfig{ + HostKeyCallback: KeyScanCallback, + } + + var serverAndPort string + parts := strings.Split(server, ":") + if len(parts) == 1 { + serverAndPort = fmt.Sprintf("%s:%d", server, defaultPort) + } else if len(parts) == 2 { + serverAndPort = server + } else { + return "", ErrMalformedServer + } + + conn, err := ssh.Dial("tcp", serverAndPort, config) + if key != "" { + // as long as we get the key, the function worked + err = nil + } + if conn != nil { + conn.Close() + } + return key, err +}