Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/go-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get install -y libbtrfs-dev build-essential
sudo apt-get install -y libbtrfs-dev build-essential apg jq

- uses: actions/setup-go@v3
with:
Expand Down
4 changes: 2 additions & 2 deletions cmd/garm-cli/cmd/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ func init() {

func formatGithubCredentials(creds []params.GithubCredentials) {
t := table.NewWriter()
header := table.Row{"Name", "Description", "Base URL", "API URL", "Upload URL"}
header := table.Row{"Name", "Description", "Base URL", "API URL", "Upload URL", "Type"}
t.AppendHeader(header)
for _, val := range creds {
t.AppendRow(table.Row{val.Name, val.Description, val.BaseURL, val.APIBaseURL, val.UploadBaseURL})
t.AppendRow(table.Row{val.Name, val.Description, val.BaseURL, val.APIBaseURL, val.UploadBaseURL, val.AuthType})
t.AppendSeparator()
}
fmt.Println(t.Render())
Expand Down
198 changes: 171 additions & 27 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,34 @@
package config

import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"log/slog"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"time"

"github.com/BurntSushi/toml"
"github.com/bradleyfalzon/ghinstallation/v2"
zxcvbn "github.com/nbutton23/zxcvbn-go"
"github.com/pkg/errors"
"golang.org/x/oauth2"

"github.com/cloudbase/garm/params"
"github.com/cloudbase/garm/util/appdefaults"
)

type (
DBBackendType string
LogLevel string
LogFormat string
DBBackendType string
LogLevel string
LogFormat string
GithubAuthType string
)

const (
Expand Down Expand Up @@ -67,6 +73,13 @@ const (
FormatJSON LogFormat = "json"
)

const (
// GithubAuthTypePAT is the OAuth token based authentication
GithubAuthTypePAT GithubAuthType = "pat"
// GithubAuthTypeApp is the GitHub App based authentication
GithubAuthTypeApp GithubAuthType = "app"
)

// NewConfig returns a new Config
func NewConfig(cfgFile string) (*Config, error) {
var config Config
Expand All @@ -93,35 +106,35 @@ type Config struct {
// Validate validates the config
func (c *Config) Validate() error {
if err := c.APIServer.Validate(); err != nil {
return errors.Wrap(err, "validating APIServer config")
return fmt.Errorf("error validating apiserver config: %w", err)
}
if err := c.Database.Validate(); err != nil {
return errors.Wrap(err, "validating database config")
return fmt.Errorf("error validating database config: %w", err)
}

if err := c.Default.Validate(); err != nil {
return errors.Wrap(err, "validating default section")
return fmt.Errorf("error validating default config: %w", err)
}

for _, gh := range c.Github {
if err := gh.Validate(); err != nil {
return errors.Wrap(err, "validating github config")
return fmt.Errorf("error validating github config: %w", err)
}
}

if err := c.JWTAuth.Validate(); err != nil {
return errors.Wrap(err, "validating jwt config")
return fmt.Errorf("error validating jwt_auth config: %w", err)
}

if err := c.Logging.Validate(); err != nil {
return errors.Wrap(err, "validating logging config")
return fmt.Errorf("error validating logging config: %w", err)
}

providerNames := map[string]int{}

for _, provider := range c.Providers {
if err := provider.Validate(); err != nil {
return errors.Wrap(err, "validating provider")
return fmt.Errorf("error validating provider %s: %w", provider.Name, err)
}
providerNames[provider.Name]++
}
Expand Down Expand Up @@ -204,15 +217,54 @@ func (d *Default) Validate() error {
}
_, err := url.Parse(d.CallbackURL)
if err != nil {
return errors.Wrap(err, "validating callback_url")
return fmt.Errorf("invalid callback_url: %w", err)
}

if d.MetadataURL == "" {
return fmt.Errorf("missing metadata-url")
return fmt.Errorf("missing metadata_url")
}

if _, err := url.Parse(d.MetadataURL); err != nil {
return errors.Wrap(err, "validating metadata_url")
return fmt.Errorf("invalid metadata_url: %w", err)
}

return nil
}

type GithubPAT struct {
OAuth2Token string `toml:"oauth2_token" json:"oauth2-token"`
}

type GithubApp struct {
AppID int64 `toml:"app_id" json:"app-id"`
PrivateKeyPath string `toml:"private_key_path" json:"private-key-path"`
InstallationID int64 `toml:"installation_id" json:"installation-id"`
}

func (a *GithubApp) Validate() error {
if a.AppID == 0 {
return fmt.Errorf("missing app_id")
}
if a.PrivateKeyPath == "" {
return fmt.Errorf("missing private_key_path")
}
if a.InstallationID == 0 {
return fmt.Errorf("missing installation_id")
}

if _, err := os.Stat(a.PrivateKeyPath); err != nil {
return fmt.Errorf("error accessing private_key_path: %w", err)
}
// Read the private key as bytes
keyBytes, err := os.ReadFile(a.PrivateKeyPath)
if err != nil {
return fmt.Errorf("reading private_key_path: %w", err)
}
block, _ := pem.Decode(keyBytes)
// Parse the private key as PCKS1
_, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return fmt.Errorf("parsing private_key_path: %w", err)
}

return nil
Expand All @@ -221,16 +273,22 @@ func (d *Default) Validate() error {
// Github hold configuration options specific to interacting with github.
// Currently that is just a OAuth2 personal token.
type Github struct {
Name string `toml:"name" json:"name"`
Description string `toml:"description" json:"description"`
Name string `toml:"name" json:"name"`
Description string `toml:"description" json:"description"`
// OAuth2Token is the personal access token used to authenticate with the
// github API. This is deprecated and will be removed in the future.
// Use the PAT section instead.
OAuth2Token string `toml:"oauth2_token" json:"oauth2-token"`
APIBaseURL string `toml:"api_base_url" json:"api-base-url"`
UploadBaseURL string `toml:"upload_base_url" json:"upload-base-url"`
BaseURL string `toml:"base_url" json:"base-url"`
// CACertBundlePath is the path on disk to a CA certificate bundle that
// can validate the endpoints defined above. Leave empty if not using a
// self signed certificate.
CACertBundlePath string `toml:"ca_cert_bundle" json:"ca-cert-bundle"`
CACertBundlePath string `toml:"ca_cert_bundle" json:"ca-cert-bundle"`
AuthType GithubAuthType `toml:"auth_type" json:"auth-type"`
PAT GithubPAT `toml:"pat" json:"pat"`
App GithubApp `toml:"app" json:"app"`
}

func (g *Github) APIEndpoint() string {
Expand All @@ -246,12 +304,12 @@ func (g *Github) CACertBundle() ([]byte, error) {
return nil, nil
}
if _, err := os.Stat(g.CACertBundlePath); err != nil {
return nil, errors.Wrap(err, "accessing CA bundle")
return nil, fmt.Errorf("error accessing ca_cert_bundle: %w", err)
}

contents, err := os.ReadFile(g.CACertBundlePath)
if err != nil {
return nil, errors.Wrap(err, "reading CA bundle")
return nil, fmt.Errorf("reading ca_cert_bundle: %w", err)
}

roots := x509.NewCertPool()
Expand Down Expand Up @@ -280,13 +338,99 @@ func (g *Github) BaseEndpoint() string {
}

func (g *Github) Validate() error {
if g.OAuth2Token == "" {
return fmt.Errorf("missing github oauth2 token")
if g.Name == "" {
return fmt.Errorf("missing credentials name")
}
if g.Description == "" {
return fmt.Errorf("missing credentials description")
}

if g.APIBaseURL != "" {
if _, err := url.ParseRequestURI(g.APIBaseURL); err != nil {
return fmt.Errorf("invalid api_base_url: %w", err)
}
}

if g.UploadBaseURL != "" {
if _, err := url.ParseRequestURI(g.UploadBaseURL); err != nil {
return fmt.Errorf("invalid upload_base_url: %w", err)
}
}

if g.BaseURL != "" {
if _, err := url.ParseRequestURI(g.BaseURL); err != nil {
return fmt.Errorf("invalid base_url: %w", err)
}
}

switch g.AuthType {
case GithubAuthTypeApp:
if err := g.App.Validate(); err != nil {
return fmt.Errorf("invalid github app config: %w", err)
}
default:
if g.OAuth2Token == "" && g.PAT.OAuth2Token == "" {
return fmt.Errorf("missing github oauth2 token")
}
if g.OAuth2Token != "" {
slog.Warn("the github.oauth2_token option is deprecated, please use the PAT section")
}
}

return nil
}

func (g *Github) HTTPClient(ctx context.Context) (*http.Client, error) {
if err := g.Validate(); err != nil {
return nil, fmt.Errorf("invalid github config: %w", err)
}
var roots *x509.CertPool
caBundle, err := g.CACertBundle()
if err != nil {
return nil, fmt.Errorf("fetching CA cert bundle: %w", err)
}
if caBundle != nil {
roots = x509.NewCertPool()
ok := roots.AppendCertsFromPEM(caBundle)
if !ok {
return nil, fmt.Errorf("failed to parse CA cert")
}
}
// nolint:golangci-lint,gosec,godox
// TODO: set TLS MinVersion
httpTransport := &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: roots,
},
}

var tc *http.Client
switch g.AuthType {
case GithubAuthTypeApp:
itr, err := ghinstallation.NewKeyFromFile(httpTransport, g.App.AppID, g.App.InstallationID, g.App.PrivateKeyPath)
if err != nil {
return nil, fmt.Errorf("failed to create github app installation transport: %w", err)
}

tc = &http.Client{Transport: itr}
default:
httpClient := &http.Client{Transport: httpTransport}
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)

token := g.PAT.OAuth2Token
if token == "" {
token = g.OAuth2Token
}

ts := oauth2.StaticTokenSource(
&oauth2.Token{AccessToken: token},
)
tc = oauth2.NewClient(ctx, ts)
}

return tc, nil
}

// Provider holds access information for a particular provider.
// A provider offers compute resources on which we spin up self hosted runners.
type Provider struct {
Expand All @@ -308,7 +452,7 @@ func (p *Provider) Validate() error {
switch p.ProviderType {
case params.ExternalProvider:
if err := p.External.Validate(); err != nil {
return errors.Wrap(err, "validating external provider info")
return fmt.Errorf("invalid external provider config: %w", err)
}
default:
return fmt.Errorf("unknown provider type: %s", p.ProviderType)
Expand Down Expand Up @@ -371,11 +515,11 @@ func (d *Database) Validate() error {
switch d.DbBackend {
case MySQLBackend:
if err := d.MySQL.Validate(); err != nil {
return errors.Wrap(err, "validating mysql config")
return fmt.Errorf("validating mysql config: %w", err)
}
case SQLiteBackend:
if err := d.SQLite.Validate(); err != nil {
return errors.Wrap(err, "validating sqlite3 config")
return fmt.Errorf("validating sqlite3 config: %w", err)
}
default:
return fmt.Errorf("invalid database backend: %s", d.DbBackend)
Expand All @@ -399,7 +543,7 @@ func (s *SQLite) Validate() error {

parent := filepath.Dir(s.DBFile)
if _, err := os.Stat(parent); err != nil {
return errors.Wrapf(err, "accessing db_file parent dir: %s", parent)
return fmt.Errorf("parent directory of db_file does not exist: %w", err)
}
return nil
}
Expand Down Expand Up @@ -513,7 +657,7 @@ func (a *APIServer) BindAddress() string {
func (a *APIServer) Validate() error {
if a.UseTLS {
if err := a.TLSConfig.Validate(); err != nil {
return errors.Wrap(err, "TLS validation failed")
return fmt.Errorf("invalid tls config: %w", err)
}
}
if a.Port > 65535 || a.Port < 1 {
Expand Down Expand Up @@ -558,7 +702,7 @@ func (d *timeToLive) Duration() time.Duration {
func (d *timeToLive) UnmarshalText(text []byte) error {
_, err := time.ParseDuration(string(text))
if err != nil {
return errors.Wrap(err, "parsing time_to_live")
return fmt.Errorf("invalid duration: %w", err)
}

*d = timeToLive(text)
Expand All @@ -574,7 +718,7 @@ type JWTAuth struct {
// Validate validates the JWTAuth config
func (j *JWTAuth) Validate() error {
if _, err := j.TimeToLive.ParseDuration(); err != nil {
return errors.Wrap(err, "parsing duration")
return fmt.Errorf("invalid time_to_live: %w", err)
}

if j.Secret == "" {
Expand Down
Loading