Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 105 additions & 53 deletions akd-client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package main
import (
"fmt"
"os"
"io"
"net"
"regexp"
"strings"
"flag"
"errors"
"bufio"
"syscall"
"net/http"
"path/filepath"
"encoding/base64"
"golang.org/x/crypto/ssh"
Expand All @@ -22,6 +24,8 @@ type Config struct {
RecordName string
PubkeyStr string `toml:"Pubkey"`
pubkey *crypto.Key
Url string
AllowUrlFallback bool
AcceptUnverified bool
OverwriteAuthorizedKeys bool
AuthorizedKeysPath string
Expand Down Expand Up @@ -61,8 +65,13 @@ func loadConfig(path string) (Config, error) {
return config, err
}

// Ensure we've been given either a record name or URL
if config.RecordName == "" && config.Url == "" {
return config, errors.New("No value for RecordName or Url provided, will not be able to retrieve any keys")
}

// Parse the pubkey if we've been given one
if config.PubkeyStr != "" {
if config.RecordName != "" && config.PubkeyStr != "" {
key, err := crypto.NewKeyFromArmored(config.PubkeyStr)
if err != nil {
return config, errors.New("Failed to parse key from config file")
Expand Down Expand Up @@ -148,24 +157,16 @@ func validateAuthorizedKeys(keys string) (bool, error) {
return true, scanner.Err()
}

func main() {
// Parse CLI args
args := parseArgs()

// Load in config
config, err := loadConfig(args.ConfigPath)
if err != nil {
fmt.Fprintln(os.Stderr, "Failed to load config from " + args.ConfigPath + ": " + err.Error())
return
}

// Gets authorised keys from an AKD/S record in DNS.
// Returns the key list, whether it was verified with PGP, and any error encountered
func getAKDKeys(record_name string, pubkey *crypto.Key, accept_unverified bool) (string, bool, error) {
// Retrieve AKD/S records
records, _ := net.LookupTXT(config.RecordName)
records, _ := net.LookupTXT(record_name)

var record_type, key_blob, sig_blob string
for _, record := range records {
fmt.Fprintln(os.Stderr, "Record: " + record)

// Try parse the record out into its constituent blobs
var err error
record_type, key_blob, sig_blob, err = parseAKDRecord(record)
Expand All @@ -180,67 +181,118 @@ func main() {

// Make sure a record was chosen
if record_type == "" {
fmt.Fprintln(os.Stderr, "No suitable AKD/S record found")
return
}

fmt.Fprintln(os.Stderr, "Record type: " + record_type)
fmt.Fprint(os.Stderr, "Has keys? ")
if len(key_blob) > 0 {
fmt.Fprintln(os.Stderr, "Yes")
} else {
fmt.Fprintln(os.Stderr, "No")
}
fmt.Fprint(os.Stderr, "Has signature? ")
if len(sig_blob) > 0 {
fmt.Fprintln(os.Stderr, "Yes")
} else {
fmt.Fprintln(os.Stderr, "No")
return "", false, errors.New("No suitable AKD/S record found")
} else if record_type == "akd" && !accept_unverified {
return "", false, errors.New("Found AKD record but not accepting unverified records")
}

// Attempt to decode the key blob from base64
var key []byte
key, err = base64.StdEncoding.DecodeString(key_blob)
key, err := base64.StdEncoding.DecodeString(key_blob)
if err != nil {
fmt.Fprintln(os.Stderr, "Failed to decode key blob: " + err.Error())
return
return "", false, errors.New("Failed to decode key blob: " + err.Error())
}

// Do the same for the signature blob if this is an AKDS record
var sig []byte
if record_type == "akds" {
// Attempt to decode the signature blob
sig, err = base64.StdEncoding.DecodeString(sig_blob)
if err != nil {
fmt.Fprintln(os.Stderr, "Failed to decode signature: " + err.Error())
if !config.AcceptUnverified { return }
if !accept_unverified {
return "", false, errors.New("Failed to decode signature: " + err.Error())
} else {
fmt.Fprintln(os.Stderr, "Failed to decode signature: " + err.Error())
}
}

// Perform signature verification if we have a signature to verify
verified, err := verifySignature(key, sig, pubkey)
if err != nil && !accept_unverified {
return "", false, errors.New("Failed to verify AKDS signature: " + err.Error())
}

return string(key), (verified && err == nil), nil
} else {
return string(key), false, nil
}
}

// Validate the key blob to ensure it conforms with the OpenSSH authorized_keys format
var valid bool
valid, err = validateAuthorizedKeys(string(key))
if err != nil || !valid {
fmt.Fprintln(os.Stderr, "Failed to validate key blob format")
// Gets authorised keys from the given URL.
// Returns the key list and any error encountered
func getUrlKeys(url string) (string, error) {
response, err := http.Get(url)
if err != nil {
return "", errors.New("Error when requesting URL: " + err.Error())
}

if response.StatusCode >= 400 {
return "", errors.New("Unsuccessful response code when requesting URL: " + response.Status)
}

defer response.Body.Close()
body, err := io.ReadAll(response.Body)
if err != nil {
return "", errors.New("Failed to read response body: " + err.Error())
}

return string(body), nil
}

func main() {
// Parse CLI args
args := parseArgs()

// Load in config
config, err := loadConfig(args.ConfigPath)
if err != nil {
fmt.Fprintln(os.Stderr, "Failed to load config from " + args.ConfigPath + ": " + err.Error())
return
}

// Perform signature verification if this is an AKDS record and we have a signature to verify
if record_type == "akds" {
verified, err := verifySignature(key, sig, config.pubkey)
// Prioritise AKD if possible
var keys string
if config.RecordName != "" {
var verified bool
keys, verified, err = getAKDKeys(config.RecordName, config.pubkey, config.AcceptUnverified)
if err != nil {
fmt.Fprintln(os.Stderr, "Failed to verify AKDS signature: " + err.Error())
if !config.AcceptUnverified { return }
// Print out the error but don't return yet, we'll give the URL a try
fmt.Fprintln(os.Stderr, "Failed to get keys from AKD/S record: " + err.Error())

// Stop here if URL fallback is possible but not allowed
if config.Url != "" && !config.AllowUrlFallback {
fmt.Fprintln(os.Stderr, "URL specified but fallback not allowed")
return
}
} else {
if verified {
fmt.Fprintln(os.Stderr, "Successfully verified AKDS data")
} else {
fmt.Fprintln(os.Stderr, "Accepting unverified AKDS data")
}
}
}

if verified && err == nil {
fmt.Fprintln(os.Stderr, "Successfully verified AKDS data")
} else {
fmt.Fprintln(os.Stderr, "Accepting unverified AKDS data")
// Use URL if AKD not available
if config.RecordName == "" || (err != nil && config.Url != "") {
keys, err = getUrlKeys(config.Url)
if err != nil {
fmt.Fprintln(os.Stderr, "Failed to get keys from URL: " + err.Error())
return
}
} else if err != nil && config.Url == "" {
fmt.Fprintln(os.Stderr, "Failed to get any keys: " + err.Error())
return
}

// Validate the key blob to ensure it conforms with the OpenSSH authorized_keys format
var valid bool
valid, err = validateAuthorizedKeys(keys)
if err != nil || !valid {
fmt.Fprintln(os.Stderr, "Failed to validate key blob format")
return
}

// Print out for OpenSSH to handle
fmt.Print(string(key))
fmt.Print(keys)

// Try writing out to authorized_keys, if enabled
if config.OverwriteAuthorizedKeys {
Expand All @@ -262,7 +314,7 @@ func main() {
}

// Write out the keys
_, err = file.Write(key)
_, err = file.Write([]byte(keys))
if err != nil {
fmt.Fprintln(os.Stderr, "Failed to write authorized_keys file to "+path)
return
Expand Down
7 changes: 7 additions & 0 deletions config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ pubkey = """
# Missing signatures on AKDS records will always produce an error
acceptUnverified = false

# URL to pull keys from
# Note that the response is pulled as-is and will not be verified, unlike AKDS
url = "https://example.com/keys"

# Whether to allow fallback to pulling keys from URL if AKD/S fails
allowUrlFallback = true

# Whether to overwrite authorized_keys with AKD/S results
overwriteAuthorizedKeys = false

Expand Down