Skip to content
4 changes: 2 additions & 2 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ func MakeCertRequest() CertRequest {
type SigningRequest struct {
signedCert ssh.Certificate
requestID string
config ssh_ca_util.SignerConfig
config ssh_ca_util.RequesterConfig
}

func MakeSigningRequest(cert ssh.Certificate, requestID string, config ssh_ca_util.SignerConfig) SigningRequest {
func MakeSigningRequest(cert ssh.Certificate, requestID string, config ssh_ca_util.RequesterConfig) SigningRequest {
var request SigningRequest
request.signedCert = cert
request.requestID = requestID
Expand Down
8 changes: 6 additions & 2 deletions get_cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func getCertFlags() []cli.Flag {
Value: configPath,
Usage: "Path to config.json",
},
cli.BoolTFlag{
cli.BoolFlag{
Name: "add-key",
Usage: "When set automatically call ssh-add",
},
Expand Down Expand Up @@ -65,7 +65,7 @@ func getCert(c *cli.Context) error {
if err != nil {
return cli.NewExitError(fmt.Sprintf("%s", err), 1)
}
if c.BoolT("add-key") {
if c.Bool("add-key") {
err = addCertToAgent(cert, sshDir)
if err != nil {
return cli.NewExitError(fmt.Sprintf("%s", err), 1)
Expand Down Expand Up @@ -94,6 +94,9 @@ func addCertToAgent(cert *ssh.Certificate, sshDir string) error {
}

func downloadCert(config ssh_ca_util.RequesterConfig, certRequestID string, sshDir string) (*ssh.Certificate, error) {
ssh_ca_util.StartTunnelIfNeeded(&config)
//fmt.Printf("get_cert downloadCert using signer url: %s", config.SignerUrl)

getResp, err := http.Get(config.SignerUrl + "cert/requests/" + certRequestID)
if err != nil {
return nil, fmt.Errorf("Didn't get a valid response: %s", err)
Expand All @@ -119,6 +122,7 @@ func downloadCert(config ssh_ca_util.RequesterConfig, certRequestID string, sshD
return nil, err
}
pubKeyPath = strings.Replace(pubKeyPath, ".pub", "-cert.pub", 1)
fmt.Printf("%s\n", getRespBuf)
err = ioutil.WriteFile(pubKeyPath, getRespBuf, 0644)
if err != nil {
fmt.Printf("Couldn't write certificate file to %s: %s\n", pubKeyPath, err)
Expand Down
4 changes: 3 additions & 1 deletion list_requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func listCertFlags() []cli.Flag {
if home == "" {
home = "/"
}
configPath := home + "/.ssh_ca/signer_config.json"
configPath := home + "/.ssh_ca/requester_config.json"

return []cli.Flag{
cli.StringFlag{
Expand Down Expand Up @@ -51,6 +51,8 @@ func listCerts(c *cli.Context) error {
return cli.NewExitError(fmt.Sprintf("%s", err), 1)
}
config := wrongTypeConfig.(ssh_ca_util.RequesterConfig)

ssh_ca_util.StartTunnelIfNeeded(&config)

getResp, err := http.Get(config.SignerUrl + "cert/requests")
if err != nil {
Expand Down
17 changes: 8 additions & 9 deletions request_cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ func requestCertFlags() []cli.Flag {
Name: "quiet",
Usage: "Print only the request id on success",
},
cli.BoolTFlag{
Name: "add-key",
Usage: "When set automatically call ssh-add if cert was auto-signed by server",
cli.BoolFlag{
Name: "no-get-key",
Usage: "When set don't automatically download the key",
},
cli.StringFlag{
Name: "ssh-dir",
Expand All @@ -94,6 +94,8 @@ func requestCert(c *cli.Context) error {
return cli.NewExitError(fmt.Sprintf("%s", err), 1)
}
config := wrongTypeConfig.(ssh_ca_util.RequesterConfig)

ssh_ca_util.StartTunnelIfNeeded(&config)

reason := c.String("reason")
if reason == "" {
Expand Down Expand Up @@ -176,15 +178,12 @@ func requestCert(c *cli.Context) error {
appendage = " auto-signed"
}
fmt.Printf("Cert request id: %s%s\n", requestID, appendage)
if signed && c.BoolT("add-key") {
cert, err := downloadCert(config, requestID, sshDir)
if err != nil {
return cli.NewExitError(fmt.Sprintf("%s", err), 1)
}
err = addCertToAgent(cert, sshDir)
if signed && !c.Bool("no-get-key") {
_, err := downloadCert(config, requestID, sshDir)
if err != nil {
return cli.NewExitError(fmt.Sprintf("%s", err), 1)
}
// add cert to agent didn't seem to work and seemed unnecessary
}
}
} else {
Expand Down
8 changes: 5 additions & 3 deletions sign_cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func signCertFlags() []cli.Flag {
if home == "" {
home = "/"
}
configPath := home + "/.ssh_ca/signer_config.json"
configPath := home + "/.ssh_ca/requester_config.json"

return []cli.Flag{
cli.StringFlag{
Expand All @@ -53,7 +53,7 @@ func signCertFlags() []cli.Flag {

func signCert(c *cli.Context) error {
configPath := c.String("config-file")
allConfig := make(map[string]ssh_ca_util.SignerConfig)
allConfig := make(map[string]ssh_ca_util.RequesterConfig)
err := ssh_ca_util.LoadConfig(configPath, &allConfig)
if err != nil {
return cli.NewExitError(fmt.Sprintf("Load Config failed: %s", err), 1)
Expand All @@ -71,7 +71,9 @@ func signCert(c *cli.Context) error {
if err != nil {
return cli.NewExitError(fmt.Sprintf("%s", err), 1)
}
config := wrongTypeConfig.(ssh_ca_util.SignerConfig)
config := wrongTypeConfig.(ssh_ca_util.RequesterConfig)

ssh_ca_util.StartTunnelIfNeeded(&config)

conn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
if err != nil {
Expand Down
27 changes: 5 additions & 22 deletions util/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ type RequesterConfig struct {
PublicKeyPath string `json:",omitempty"`
PublicKeyFingerprint string `json:",omitempty"`
SignerUrl string
SshBastion string `json:",omitempty"`
KeyFingerprint string `json:",omitempty"`
}

type SignerdConfig struct {
Expand All @@ -25,19 +27,14 @@ type SignerdConfig struct {
CriticalOptions map[string]string
}

type SignerConfig struct {
KeyFingerprint string
SignerUrl string
}

func LoadConfig(configPath string, environmentConfigs interface{}) error {
buf, err := ioutil.ReadFile(configPath)
if err != nil {
return err
}

switch configType := environmentConfigs.(type) {
case *map[string]RequesterConfig, *map[string]SignerConfig, *map[string]SignerdConfig:
case *map[string]RequesterConfig, *map[string]SignerdConfig:
return json.Unmarshal(buf, &environmentConfigs)
default:
return fmt.Errorf("oops: %T\n", configType)
Expand All @@ -56,24 +53,10 @@ func GetConfigForEnv(environment string, environmentConfigs interface{}) (interf
// lame way of extracting first and only key from a map?
}
}

config, ok := configs[environment]
if !ok {
return nil, fmt.Errorf("Requested environment not found in config file.")
}
return config, nil
case *map[string]SignerConfig:
configs := *environmentConfigs.(*map[string]SignerConfig)
if len(configs) > 1 && environment == "" {
return nil, fmt.Errorf("You must tell me which environment to use.")
}
if len(configs) == 1 && environment == "" {
for environment = range configs {
// lame way of extracting first and only key from a map?
}
}
config, ok := configs[environment]
if !ok {
return nil, fmt.Errorf("Requested environment not found in config file.")
return nil, fmt.Errorf("Requested environment not found in config file1.")
}
return config, nil
}
Expand Down
190 changes: 190 additions & 0 deletions util/sshtunnel.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
// copied from https://gist.github.com/svett/5d695dcc4cc6ad5dd275

package ssh_ca_util

import (
// "log"
// "bufio"
// "time"
"os"
"fmt"
"io"
"strings"
"strconv"
"net"
"sync"
"net/url"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)

type Endpoint struct {
Host string
Port int
}

func (endpoint *Endpoint) String() string {
return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port)
}

type SSHtunnel struct {
Local *Endpoint
Server *Endpoint
Remote *Endpoint

Config *ssh.ClientConfig
}

func (tunnel *SSHtunnel) Start(out_port chan int) error {
listener, err := net.Listen("tcp", tunnel.Local.String())
if err != nil {
return err
}
//fmt.Println("Using local port:", listener.Addr().(*net.TCPAddr).Port)
//tunnel.LocalPort = listener.Addr().(*net.TCPAddr).Port
out_port <- listener.Addr().(*net.TCPAddr).Port
defer listener.Close()

for {
conn, err := listener.Accept()
if err != nil {
return err
}
go tunnel.forward(conn)
}
}

func (tunnel *SSHtunnel) forward(localConn net.Conn) {
serverConn, err := ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config)
if err != nil {
fmt.Printf("Server dial error: %s\n", err)
return
}

remoteConn, err := serverConn.Dial("tcp", tunnel.Remote.String())
if err != nil {
fmt.Printf("Remote dial error: %s\n", err)
return
}

copyConn:=func(writer, reader net.Conn) {
_, err:= io.Copy(writer, reader)
if err != nil {
fmt.Printf("io.Copy error: %s", err)
}
}

go copyConn(localConn, remoteConn)
go copyConn(remoteConn, localConn)
}

func SSHAgent() ssh.AuthMethod {
if sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers)
}
return nil
}

var (
jobIsRunning bool
JobIsrunningMu sync.Mutex
)

func StartTunnelIfNeeded(config *RequesterConfig) {
if len(config.SshBastion) > 0 {

JobIsrunningMu.Lock()
start := !jobIsRunning
jobIsRunning = true
JobIsrunningMu.Unlock()
if start {
if !strings.HasPrefix(config.SshBastion, "ssh://") {
fmt.Printf("Bastion host must start with ssh://. Exiting\n")
os.Exit(1)
}

bastion_parsed, err := url.Parse(config.SshBastion)
if err != nil {
fmt.Printf("url.Parse error for SshBastion: %s", err)
}

// Check to see if it's a nonstardard port
host_parts := strings.Split(bastion_parsed.Host, ":")
var ssh_port int
ssh_port = 22
if len(host_parts) == 2 {
var err error
ssh_port, err = strconv.Atoi(host_parts[1])
if err != nil {
fmt.Printf("strconv.Atoi error: %s", err)
}
}

// Get remote end information
remote_parsed, err := url.Parse(config.SignerUrl)
if err != nil {
fmt.Printf("url.Parse error on SignerUrl: %s", err)
}
remote_parts := strings.Split(remote_parsed.Host, ":")
if len(remote_parts) != 2 {
fmt.Printf("Missing port for SignerUrl. Exiting")
os.Exit(1)
}
remote_port, err := strconv.Atoi(remote_parts[1])
if err != nil {
fmt.Printf("strconv.Atoi error: %s", err)
}

//fmt.Printf("config stuff: %s, %d\n", host_parts[0], ssh_port)
//fmt.Printf("starting tunnel config...\n")
localEndpoint := &Endpoint{
Host: "localhost",
Port: 0,
}

serverEndpoint := &Endpoint{
Host: host_parts[0],
Port: ssh_port,
}

remoteEndpoint := &Endpoint{
Host: remote_parts[0],
Port: remote_port,
}

sshConfig := &ssh.ClientConfig{
User: bastion_parsed.User.Username(),
Auth: []ssh.AuthMethod{
SSHAgent(),
},
// TODO: fix this to actually check the trusted hosts
// https://utcc.utoronto.ca/~cks/space/blog/programming/GoSSHHostKeyCheckingNotes
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
},
}

tunnel := &SSHtunnel{
Config: sshConfig,
Local: localEndpoint,
Server: serverEndpoint,
Remote: remoteEndpoint,
}

//fmt.Printf("starting tunnel...\n")

out_port_chan := make(chan int)
go tunnel.Start(out_port_chan)
var local_port int
local_port = <- out_port_chan
//fmt.Printf("Using local port: %d\n", local_port)

//fmt.Printf("doing normal stuff...\n")

config.SignerUrl = fmt.Sprintf("%s://localhost:%d/", remote_parsed.Scheme, local_port)
//fmt.Printf("sshtunnel using signer url: %s", config.SignerUrl)
// end new stuff
}
}
}