@@ -4,12 +4,17 @@ import (
44 "bytes"
55 "crypto/rand"
66 "crypto/rsa"
7+ "errors"
78 "fmt"
8- "os/exec"
9+ "net"
10+ "os"
11+ "path"
912 "strconv"
1013 "strings"
1114
1215 log "github.com/sirupsen/logrus"
16+ "golang.org/x/crypto/ssh"
17+ kh "golang.org/x/crypto/ssh/knownhosts"
1318)
1419
1520func GenerateSSHKeyPair () (* rsa.PrivateKey , error ) {
@@ -64,6 +69,69 @@ type RemoteRunOpts struct {
6469 User string
6570}
6671
72+ func checkKnownHosts () (ssh.HostKeyCallback , error ) {
73+ f , fErr := os .OpenFile (getHostKeyPath (), os .O_CREATE , 0600 )
74+ if fErr != nil {
75+ log .Fatal (fErr )
76+ }
77+ _ = f .Close ()
78+ return kh .New (getHostKeyPath ())
79+ }
80+
81+ func getClient (user string ) (* ssh.Client , error ) {
82+ var authMethod []ssh.AuthMethod
83+ if user == "stackhead" {
84+ authMethod = append (authMethod , publicKey (Context .Authentication .GetPrivateKeyPath ()))
85+ } else {
86+ authMethod = append (authMethod , publicKey (path .Join (os .Getenv ("HOME" ), ".ssh" , "id_rsa" )))
87+ }
88+
89+ config := & ssh.ClientConfig {
90+ User : user ,
91+ Auth : authMethod ,
92+ HostKeyCallback : ssh .HostKeyCallback (func (host string , remote net.Addr , pubKey ssh.PublicKey ) error {
93+ knownHosts , err := checkKnownHosts ()
94+ if err != nil {
95+ return err
96+ }
97+ var keyErr * kh.KeyError
98+ hErr := knownHosts (host , remote , pubKey )
99+ if errors .As (hErr , & keyErr ) && len (keyErr .Want ) > 0 {
100+ // Reference: https://www.godoc.org/golang.org/x/crypto/ssh/knownhosts#KeyError
101+ // if keyErr.Want slice is empty then host is unknown, if keyErr.Want is not empty
102+ // and if host is known then there is key mismatch the connection is then rejected.
103+ //log.Printf("WARNING: The received key is not a key of %s, either a MiTM attack or %s has reconfigured the host pub key.", host, host)
104+ return keyErr
105+ } else if errors .As (hErr , & keyErr ) && len (keyErr .Want ) == 0 {
106+ // host key not found in known_hosts then give a warning and continue to connect.
107+ //log.Printf("WARNING: %s is not trusted, adding key to known_hosts file.", host)
108+ return addHostKey (remote , pubKey )
109+ }
110+ //log.Printf("Pub key exists for %s.", host)
111+ return nil
112+ }),
113+ }
114+ return ssh .Dial ("tcp" , fmt .Sprintf ("%s:%s" , Context .TargetHost , "22" ), config )
115+ }
116+
117+ func getHostKeyPath () string {
118+ return path .Join (os .Getenv ("HOME" ), ".ssh" , "stackhead_known_hosts" )
119+ }
120+
121+ func addHostKey (remote net.Addr , pubKey ssh.PublicKey ) error {
122+ // add host key if host is not found in known_hosts, error object is return, if nil then connection proceeds,
123+ // if not nil then connection stops.
124+ f , fErr := os .OpenFile (getHostKeyPath (), os .O_APPEND | os .O_WRONLY , 0600 )
125+ if fErr != nil {
126+ return fErr
127+ }
128+ defer f .Close ()
129+
130+ knownHosts := kh .Normalize (remote .String ())
131+ _ , fileErr := f .WriteString (kh .Line ([]string {knownHosts }, pubKey ))
132+ return fileErr
133+ }
134+
67135func RemoteRun (cmd string , opts RemoteRunOpts ) (bytes.Buffer , bytes.Buffer , error ) {
68136 user := getRemoteUser ()
69137 if opts .User != "" {
@@ -81,13 +149,6 @@ func RemoteRun(cmd string, opts RemoteRunOpts) (bytes.Buffer, bytes.Buffer, erro
81149 }
82150 }
83151
84- var cmdArgs []string
85- if user == "stackhead" {
86- cmdArgs = []string {"-i" , Context .Authentication .GetPrivateKeyPath ()}
87- }
88-
89- cmdArgs = append (cmdArgs , fmt .Sprintf ("%s@%s" , user , Context .TargetHost ))
90-
91152 if opts .Sudo {
92153 remoteCmd = "sudo " + remoteCmd
93154 }
@@ -99,15 +160,37 @@ func RemoteRun(cmd string, opts RemoteRunOpts) (bytes.Buffer, bytes.Buffer, erro
99160 if opts .WorkingDir != "" {
100161 remoteCmd = "cd " + opts .WorkingDir + "; " + remoteCmd
101162 }
102- cmdArgs = append (cmdArgs , remoteCmd )
103- command := exec .Command ("ssh" , cmdArgs ... )
163+
104164 var out , outErr bytes.Buffer
105- command .Stdout = & out
106- command .Stderr = & outErr
107- err := command .Run ()
165+ client , err := getClient (user )
166+ if err != nil {
167+ return out , outErr , err
168+ }
169+ defer client .Close ()
170+ ss , err := client .NewSession ()
171+ if err != nil {
172+ log .Fatal ("unable to create SSH session: " , err )
173+ }
174+ defer ss .Close ()
175+
176+ ss .Stdout = & out
177+ ss .Stderr = & outErr
178+ err = ss .Run (remoteCmd )
108179 return out , outErr , err
109180}
110181
182+ func publicKey (path string ) ssh.AuthMethod {
183+ key , err := os .ReadFile (path )
184+ if err != nil {
185+ panic (err )
186+ }
187+ signer , err := ssh .ParsePrivateKey (key )
188+ if err != nil {
189+ panic (err )
190+ }
191+ return ssh .PublicKeys (signer )
192+ }
193+
111194func SimpleRemoteRun (cmd string , opts RemoteRunOpts ) (string , error ) {
112195 stdout , stderr , err := RemoteRun (cmd , opts )
113196 if err != nil {
0 commit comments