Skip to content

Commit

Permalink
cli.LoadConfig accepts config file now
Browse files Browse the repository at this point in the history
  • Loading branch information
huskyii committed Jun 5, 2022
1 parent adb55bc commit 0363e58
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 13 deletions.
20 changes: 12 additions & 8 deletions cmd/headscale/cli/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,19 @@ const (
HeadscaleDateTimeFormat = "2006-01-02 15:04:05"
)

func LoadConfig(path string) error {
viper.SetConfigName("config")
if path == "" {
viper.AddConfigPath("/etc/headscale/")
viper.AddConfigPath("$HOME/.headscale")
viper.AddConfigPath(".")
func LoadConfig(path string, isFile bool) error {
if isFile {
viper.SetConfigFile(path)
} else {
// For testing
viper.AddConfigPath(path)
viper.SetConfigName("config")
if path == "" {
viper.AddConfigPath("/etc/headscale/")
viper.AddConfigPath("$HOME/.headscale")
viper.AddConfigPath(".")
} else {
// For testing
viper.AddConfigPath(path)
}
}

viper.SetEnvPrefix("headscale")
Expand Down
2 changes: 1 addition & 1 deletion cmd/headscale/headscale.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func main() {
NoColor: !colors,
})

if err := cli.LoadConfig(""); err != nil {
if err := cli.LoadConfig("", false); err != nil {
log.Fatal().Caller().Err(err)
}

Expand Down
53 changes: 49 additions & 4 deletions cmd/headscale/headscale_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,51 @@ func (s *Suite) SetUpSuite(c *check.C) {
func (s *Suite) TearDownSuite(c *check.C) {
}

func (*Suite) TestConfigFileLoading(c *check.C) {
tmpDir, err := ioutil.TempDir("", "headscale")
if err != nil {
c.Fatal(err)
}
defer os.RemoveAll(tmpDir)

path, err := os.Getwd()
if err != nil {
c.Fatal(err)
}

cfgFile := filepath.Join(tmpDir, "config.yaml")

// Symlink the example config file
err = os.Symlink(
filepath.Clean(path+"/../../config-example.yaml"),
cfgFile,
)
if err != nil {
c.Fatal(err)
}

// Load example config, it should load without validation errors
err = cli.LoadConfig(cfgFile, true)
c.Assert(err, check.IsNil)

// Test that config file was interpreted correctly
c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080")
c.Assert(viper.GetString("listen_addr"), check.Equals, "0.0.0.0:8080")
c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090")
c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3")
c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite")
c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1")
c.Assert(
cli.GetFileMode("unix_socket_permission"),
check.Equals,
fs.FileMode(0o770),
)
c.Assert(viper.GetBool("logtail.enabled"), check.Equals, false)
}

func (*Suite) TestConfigLoading(c *check.C) {
tmpDir, err := ioutil.TempDir("", "headscale")
if err != nil {
Expand All @@ -49,7 +94,7 @@ func (*Suite) TestConfigLoading(c *check.C) {
}

// Load example config, it should load without validation errors
err = cli.LoadConfig(tmpDir)
err = cli.LoadConfig(tmpDir, false)
c.Assert(err, check.IsNil)

// Test that config file was interpreted correctly
Expand Down Expand Up @@ -92,7 +137,7 @@ func (*Suite) TestDNSConfigLoading(c *check.C) {
}

// Load example config, it should load without validation errors
err = cli.LoadConfig(tmpDir)
err = cli.LoadConfig(tmpDir, false)
c.Assert(err, check.IsNil)

dnsConfig, baseDomain := cli.GetDNSConfig()
Expand Down Expand Up @@ -125,7 +170,7 @@ func (*Suite) TestTLSConfigValidation(c *check.C) {
writeConfig(c, tmpDir, configYaml)

// Check configuration validation errors (1)
err = cli.LoadConfig(tmpDir)
err = cli.LoadConfig(tmpDir, false)
c.Assert(err, check.NotNil)
// check.Matches can not handle multiline strings
tmp := strings.ReplaceAll(err.Error(), "\n", "***")
Expand All @@ -150,6 +195,6 @@ func (*Suite) TestTLSConfigValidation(c *check.C) {
"---\nserver_url: \"http://127.0.0.1:8080\"\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"TLS-ALPN-01\"",
)
writeConfig(c, tmpDir, configYaml)
err = cli.LoadConfig(tmpDir)
err = cli.LoadConfig(tmpDir, false)
c.Assert(err, check.IsNil)
}

0 comments on commit 0363e58

Please sign in to comment.