Skip to content
This repository was archived by the owner on Feb 26, 2024. It is now read-only.

Commit b70ecb3

Browse files
committed
chore: SSH host key checks
1 parent 4f9f0fb commit b70ecb3

File tree

1 file changed

+96
-13
lines changed

1 file changed

+96
-13
lines changed

system/ssh.go

Lines changed: 96 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,17 @@ import (
44
"bytes"
55
"crypto/rand"
66
"crypto/rsa"
7+
"errors"
78
"fmt"
8-
"os/exec"
9+
"net"
10+
"os"
11+
"path"
912
"strconv"
1013
"strings"
1114

1215
log "github.com/sirupsen/logrus"
16+
"golang.org/x/crypto/ssh"
17+
kh "golang.org/x/crypto/ssh/knownhosts"
1318
)
1419

1520
func GenerateSSHKeyPair() (*rsa.PrivateKey, error) {
@@ -64,6 +69,69 @@ type RemoteRunOpts struct {
6469
User string
6570
}
6671

72+
func checkKnownHosts() (ssh.HostKeyCallback, error) {
73+
f, fErr := os.OpenFile(getHostKeyPath(), os.O_CREATE, 0600)
74+
if fErr != nil {
75+
log.Fatal(fErr)
76+
}
77+
_ = f.Close()
78+
return kh.New(getHostKeyPath())
79+
}
80+
81+
func getClient(user string) (*ssh.Client, error) {
82+
var authMethod []ssh.AuthMethod
83+
if user == "stackhead" {
84+
authMethod = append(authMethod, publicKey(Context.Authentication.GetPrivateKeyPath()))
85+
} else {
86+
authMethod = append(authMethod, publicKey(path.Join(os.Getenv("HOME"), ".ssh", "id_rsa")))
87+
}
88+
89+
config := &ssh.ClientConfig{
90+
User: user,
91+
Auth: authMethod,
92+
HostKeyCallback: ssh.HostKeyCallback(func(host string, remote net.Addr, pubKey ssh.PublicKey) error {
93+
knownHosts, err := checkKnownHosts()
94+
if err != nil {
95+
return err
96+
}
97+
var keyErr *kh.KeyError
98+
hErr := knownHosts(host, remote, pubKey)
99+
if errors.As(hErr, &keyErr) && len(keyErr.Want) > 0 {
100+
// Reference: https://www.godoc.org/golang.org/x/crypto/ssh/knownhosts#KeyError
101+
// if keyErr.Want slice is empty then host is unknown, if keyErr.Want is not empty
102+
// and if host is known then there is key mismatch the connection is then rejected.
103+
//log.Printf("WARNING: The received key is not a key of %s, either a MiTM attack or %s has reconfigured the host pub key.", host, host)
104+
return keyErr
105+
} else if errors.As(hErr, &keyErr) && len(keyErr.Want) == 0 {
106+
// host key not found in known_hosts then give a warning and continue to connect.
107+
//log.Printf("WARNING: %s is not trusted, adding key to known_hosts file.", host)
108+
return addHostKey(remote, pubKey)
109+
}
110+
//log.Printf("Pub key exists for %s.", host)
111+
return nil
112+
}),
113+
}
114+
return ssh.Dial("tcp", fmt.Sprintf("%s:%s", Context.TargetHost, "22"), config)
115+
}
116+
117+
func getHostKeyPath() string {
118+
return path.Join(os.Getenv("HOME"), ".ssh", "stackhead_known_hosts")
119+
}
120+
121+
func addHostKey(remote net.Addr, pubKey ssh.PublicKey) error {
122+
// add host key if host is not found in known_hosts, error object is return, if nil then connection proceeds,
123+
// if not nil then connection stops.
124+
f, fErr := os.OpenFile(getHostKeyPath(), os.O_APPEND|os.O_WRONLY, 0600)
125+
if fErr != nil {
126+
return fErr
127+
}
128+
defer f.Close()
129+
130+
knownHosts := kh.Normalize(remote.String())
131+
_, fileErr := f.WriteString(kh.Line([]string{knownHosts}, pubKey))
132+
return fileErr
133+
}
134+
67135
func RemoteRun(cmd string, opts RemoteRunOpts) (bytes.Buffer, bytes.Buffer, error) {
68136
user := getRemoteUser()
69137
if opts.User != "" {
@@ -81,13 +149,6 @@ func RemoteRun(cmd string, opts RemoteRunOpts) (bytes.Buffer, bytes.Buffer, erro
81149
}
82150
}
83151

84-
var cmdArgs []string
85-
if user == "stackhead" {
86-
cmdArgs = []string{"-i", Context.Authentication.GetPrivateKeyPath()}
87-
}
88-
89-
cmdArgs = append(cmdArgs, fmt.Sprintf("%s@%s", user, Context.TargetHost))
90-
91152
if opts.Sudo {
92153
remoteCmd = "sudo " + remoteCmd
93154
}
@@ -99,15 +160,37 @@ func RemoteRun(cmd string, opts RemoteRunOpts) (bytes.Buffer, bytes.Buffer, erro
99160
if opts.WorkingDir != "" {
100161
remoteCmd = "cd " + opts.WorkingDir + "; " + remoteCmd
101162
}
102-
cmdArgs = append(cmdArgs, remoteCmd)
103-
command := exec.Command("ssh", cmdArgs...)
163+
104164
var out, outErr bytes.Buffer
105-
command.Stdout = &out
106-
command.Stderr = &outErr
107-
err := command.Run()
165+
client, err := getClient(user)
166+
if err != nil {
167+
return out, outErr, err
168+
}
169+
defer client.Close()
170+
ss, err := client.NewSession()
171+
if err != nil {
172+
log.Fatal("unable to create SSH session: ", err)
173+
}
174+
defer ss.Close()
175+
176+
ss.Stdout = &out
177+
ss.Stderr = &outErr
178+
err = ss.Run(remoteCmd)
108179
return out, outErr, err
109180
}
110181

182+
func publicKey(path string) ssh.AuthMethod {
183+
key, err := os.ReadFile(path)
184+
if err != nil {
185+
panic(err)
186+
}
187+
signer, err := ssh.ParsePrivateKey(key)
188+
if err != nil {
189+
panic(err)
190+
}
191+
return ssh.PublicKeys(signer)
192+
}
193+
111194
func SimpleRemoteRun(cmd string, opts RemoteRunOpts) (string, error) {
112195
stdout, stderr, err := RemoteRun(cmd, opts)
113196
if err != nil {

0 commit comments

Comments
 (0)