Skip to content

Commit

Permalink
feat: port ssh cmd from python, refactored ssh key gen
Browse files Browse the repository at this point in the history
  • Loading branch information
justinmerrell committed Feb 6, 2024
1 parent 930fc3c commit 4c90f4b
Show file tree
Hide file tree
Showing 6 changed files with 320 additions and 120 deletions.
96 changes: 68 additions & 28 deletions api/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,21 @@ package api

import (
"encoding/json"
"errors"
"fmt"
"io"
"strings"

"golang.org/x/crypto/ssh"
)

func GetPublicSSHKeys() (keys string, err error) {
type SSHKey struct {
Name string `json:"name"`
Type string `json:"type"`
Key string `json:"key"`
Fingerprint string `json:"fingerprint"`
}

func GetPublicSSHKeys() (string, []SSHKey, error) {
input := Input{
Query: `
query myself {
Expand All @@ -19,48 +27,79 @@ func GetPublicSSHKeys() (keys string, err error) {
}
`,
}

res, err := Query(input)
if err != nil {
return "", err
return "", nil, err
}
defer res.Body.Close()

if res.StatusCode != 200 {
err = fmt.Errorf("statuscode %d", res.StatusCode)
return
return "", nil, fmt.Errorf("unexpected status code: %d", res.StatusCode)
}
defer res.Body.Close()

rawData, err := io.ReadAll(res.Body)
if err != nil {
return "", err
return "", nil, fmt.Errorf("failed to read response body: %w", err)
}
data := &UserOut{}
if err = json.Unmarshal(rawData, data); err != nil {
return "", err

var data UserOut
if err := json.Unmarshal(rawData, &data); err != nil {
return "", nil, fmt.Errorf("JSON unmarshal error: %w", err)
}

if len(data.Errors) > 0 {
err = errors.New(data.Errors[0].Message)
return "", err
return "", nil, fmt.Errorf("API error: %s", data.Errors[0].Message)
}
if data == nil || data.Data == nil || data.Data.Myself == nil {
err = fmt.Errorf("data is nil: %s", string(rawData))
return "", err

if data.Data == nil || data.Data.Myself == nil {
return "", nil, fmt.Errorf("nil data received: %s", string(rawData))
}

// Parse the public key string into a list of SSHKey structs
var keys []SSHKey
keyStrings := strings.Split(data.Data.Myself.PubKey, "\n")
for _, keyString := range keyStrings {
if keyString == "" {
continue
}

pubKey, name, _, _, err := ssh.ParseAuthorizedKey([]byte(keyString))
if err != nil {
continue // Skip keys that can't be parsed
}

keys = append(keys, SSHKey{
Name: name,
Type: pubKey.Type(),
Key: string(ssh.MarshalAuthorizedKey(pubKey)),
Fingerprint: ssh.FingerprintSHA256(pubKey),
})
}
return data.Data.Myself.PubKey, nil

return data.Data.Myself.PubKey, keys, nil
}

func AddPublicSSHKey(key []byte) error {
//pull existing pubKey
existingKeys, err := GetPublicSSHKeys()
rawKeys, existingKeys, err := GetPublicSSHKeys()
if err != nil {
return err
return fmt.Errorf("failed to get existing SSH keys: %w", err)
}

keyStr := string(key)
//check for key present
if strings.Contains(existingKeys, keyStr) {
return nil
for _, k := range existingKeys {
if strings.TrimSpace(k.Key) == strings.TrimSpace(keyStr) {
return nil
}
}
// concat key onto pubKey
newKeys := existingKeys + "\n\n" + keyStr
// set new pubKey

// Concatenate the new key onto the existing keys, separated by a newline
newKeys := strings.TrimSpace(rawKeys)
if newKeys != "" {
newKeys += "\n\n"
}
newKeys += strings.TrimSpace(keyStr)

input := Input{
Query: `
mutation Mutation($input: UpdateUserSettingsInput) {
Expand All @@ -71,9 +110,10 @@ func AddPublicSSHKey(key []byte) error {
`,
Variables: map[string]interface{}{"input": map[string]interface{}{"pubKey": newKeys}},
}
_, err = Query(input)
if err != nil {
return err

if _, err = Query(input); err != nil {
return fmt.Errorf("failed to update SSH keys: %w", err)
}

return nil
}
92 changes: 22 additions & 70 deletions cmd/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,100 +2,52 @@ package config

import (
"cli/api"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"cli/cmd/ssh"
"fmt"
"os"
"path/filepath"

"github.com/spf13/cobra"
"github.com/spf13/viper"
"golang.org/x/crypto/ssh"
)

var ConfigFile string
var apiKey string
var apiUrl string
var (
ConfigFile string
apiKey string
apiUrl string
)

var ConfigCmd = &cobra.Command{
Use: "config",
Short: "CLI Config",
Long: "RunPod CLI Config Settings",
Run: func(c *cobra.Command, args []string) {
err := viper.WriteConfig()
cobra.CheckErr(err)
fmt.Println("saved apiKey into config file: " + ConfigFile)
home, err := os.UserHomeDir()
if err := viper.WriteConfig(); err != nil {
fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
return
}
fmt.Println("Configuration saved to file:", viper.ConfigFileUsed())

publicKey, err := ssh.GenerateSSHKeyPair("RunPod-Key-Go")
if err != nil {
fmt.Println("couldn't get user home dir path")
fmt.Fprintf(os.Stderr, "Failed to generate SSH key: %v\n", err)
return
}
sshFolderPath := filepath.Join(home, ".runpod", "ssh")
os.MkdirAll(sshFolderPath, os.ModePerm)
privateSshPath := filepath.Join(sshFolderPath, "RunPod-Key-Go")
publicSshPath := filepath.Join(sshFolderPath, "RunPod-Key-Go.pub")
publicKey, _ := os.ReadFile(publicSshPath)
if _, err := os.Stat(privateSshPath); errors.Is(err, os.ErrNotExist) {
publicKey = makeRSAKey(privateSshPath)

if err := api.AddPublicSSHKey(publicKey); err != nil {
fmt.Fprintf(os.Stderr, "Failed to add the SSH key: %v\n", err)
return
}
api.AddPublicSSHKey(publicKey)
fmt.Println("SSH key added successfully.")
},
}

func init() {
ConfigCmd.Flags().StringVar(&apiKey, "apiKey", "", "runpod api key")
ConfigCmd.MarkFlagRequired("apiKey")
ConfigCmd.Flags().StringVar(&apiKey, "apiKey", "", "RunPod API key")
viper.BindPFlag("apiKey", ConfigCmd.Flags().Lookup("apiKey")) //nolint
viper.SetDefault("apiKey", "")

ConfigCmd.Flags().StringVar(&apiUrl, "apiUrl", "", "runpod api url")
ConfigCmd.Flags().StringVar(&apiUrl, "apiUrl", "https://api.runpod.io/graphql", "RunPod API URL")
viper.BindPFlag("apiUrl", ConfigCmd.Flags().Lookup("apiUrl")) //nolint
viper.SetDefault("apiUrl", "https://api.runpod.io/graphql")
}

func makeRSAKey(filename string) []byte {
bitSize := 2048

// Generate RSA key.
key, err := rsa.GenerateKey(rand.Reader, bitSize)
if err != nil {
panic(err)
}

// Extract public component.
pub := key.PublicKey

// Encode private key to PKCS#1 ASN.1 PEM.
keyPEM := pem.EncodeToMemory(
&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
},
)

// generate and write public key
publicKey, err := ssh.NewPublicKey(&pub)
if err != nil {
fmt.Println("err in NewPublicKey")
fmt.Println(err)
}
pubBytes := ssh.MarshalAuthorizedKey(publicKey)
pubBytes = append(pubBytes, []byte(" "+filename)...)

// Write private key to file.
if err := os.WriteFile(filename, keyPEM, 0600); err != nil {
fmt.Println("err writing priv")
panic(err)
}

// Write public key to file.
if err := os.WriteFile(filename+".pub", pubBytes, 0600); err != nil {
fmt.Println("err writing pub")
panic(err)
}
fmt.Println("saved new SSH public key into", filename+".pub")
return pubBytes
ConfigCmd.MarkFlagRequired("apiKey")
}
54 changes: 32 additions & 22 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,41 +14,51 @@ import (

var version string

// rootCmd represents the base command when called without any subcommands
var RootCmd = &cobra.Command{
// Entrypoint for the CLI
var rootCmd = &cobra.Command{
Use: "runpodctl",
Aliases: []string{"runpod"},
Short: "CLI for runpod.io",
Long: "CLI tool to manage your pods for runpod.io",
}

// Execute adds all child commands to the root command and sets flags appropriately.
// This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute(ver string) {
version = ver
api.Version = ver
err := RootCmd.Execute()
if err != nil {
os.Exit(1)
}
func GetRootCmd() *cobra.Command {
return rootCmd
}

func init() {
cobra.OnInitialize(initConfig)
RootCmd.AddCommand(config.ConfigCmd)
registerCommands()
}

func registerCommands() {
rootCmd.AddCommand(config.ConfigCmd)
// RootCmd.AddCommand(connectCmd)
// RootCmd.AddCommand(copyCmd)
RootCmd.AddCommand(createCmd)
RootCmd.AddCommand(getCmd)
RootCmd.AddCommand(removeCmd)
RootCmd.AddCommand(startCmd)
RootCmd.AddCommand(stopCmd)
RootCmd.AddCommand(versionCmd)
RootCmd.AddCommand(projectCmd)
RootCmd.AddCommand(updateCmd)
rootCmd.AddCommand(createCmd)
rootCmd.AddCommand(getCmd)
rootCmd.AddCommand(removeCmd)
rootCmd.AddCommand(startCmd)
rootCmd.AddCommand(stopCmd)
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(projectCmd)
rootCmd.AddCommand(updateCmd)
rootCmd.AddCommand(sshCmd)

RootCmd.AddCommand(croc.ReceiveCmd)
RootCmd.AddCommand(croc.SendCmd)
// file transfer via croc
rootCmd.AddCommand(croc.ReceiveCmd)
rootCmd.AddCommand(croc.SendCmd)
}

// Execute adds all child commands to the root command and sets flags appropriately.
// This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute(ver string) {
version = ver
api.Version = ver
if err := rootCmd.Execute(); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
}

// initConfig reads in config file and ENV variables if set.
Expand Down
18 changes: 18 additions & 0 deletions cmd/ssh.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package cmd

import (
"cli/cmd/ssh"

"github.com/spf13/cobra"
)

var sshCmd = &cobra.Command{
Use: "ssh",
Short: "SSH keys and commands",
Long: "SSH key management and connection to pods",
}

func init() {
sshCmd.AddCommand(ssh.ListKeysCmd)
sshCmd.AddCommand(ssh.AddKeyCmd)
}
Loading

0 comments on commit 4c90f4b

Please sign in to comment.