Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,24 @@
default = false;
};

unixSocket = mkOption {
type = nullOr lib.types.path;
default = null;
description = "Path to unix socket to listen on";
};

disableTCP = mkOption {
type = nullOr bool;
default = null;
description = "Disable the TCP Listeners on tsnet and tailscaled";
};

serverURL = mkOption {
type = nullOr str;
default = null;
description = "Server URL to use instead of the tailscale FDQN";
};

enableFunnel = mkOption {
type = bool;
default = false;
Expand Down Expand Up @@ -228,8 +246,11 @@
args = lib.cli.toGNUCommandLineShell { mkOptionName = k: "-${k}"; } {
hostname = cfg.settings.hostName;
port = cfg.settings.port;
server-url = cfg.settings.serverURL;
local-port = cfg.settings.localPort;
use-local-tailscaled = cfg.settings.useLocalTailscaled;
unix-socket = cfg.settings.unixSocket;
disable-tcp = cfg.settings.disableTCP;
funnel = cfg.settings.enableFunnel;
enable-sts = cfg.settings.enableSts;
log = cfg.settings.logLevel;
Expand Down
142 changes: 99 additions & 43 deletions tsidp-server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"net/http"
"os"
"os/signal"
"strconv"
"strings"
"time"

Expand All @@ -39,8 +40,11 @@ var (
flagPort = flag.Int("port", 443, "port to listen on")
flagLocalPort = flag.Int("local-port", -1, "allow requests from localhost")
flagUseLocalTailscaled = flag.Bool("use-local-tailscaled", false, "use local tailscaled instead of tsnet")
flagUnixSocket = flag.String("unix-socket", "", "unix socket to listen on")
flagDisableTCP = flag.Bool("disable-tcp", false, "disable the tcp listener on tsnet/tailscaled")
flagFunnel = flag.Bool("funnel", false, "use Tailscale Funnel to make tsidp available on the public internet")
flagHostname = flag.String("hostname", "idp", "tsnet hostname to use instead of idp")
flagServerURL = flag.String("server-url", "", "server url to use instead of the tailscale FDQN.")
flagDir = flag.String("dir", "", "tsnet state directory; a default one will be created if not provided")
flagEnableSTS = flag.Bool("enable-sts", false, "enable OIDC STS token exchange support")

Expand Down Expand Up @@ -91,38 +95,41 @@ func main() {
slog.Error("getting local.Client status", slog.Any("error", err))
os.Exit(1)
}
portStr := fmt.Sprint(*flagPort)
anySuccess := false
for _, ip := range st.TailscaleIPs {
ln, err := net.Listen("tcp", net.JoinHostPort(ip.String(), portStr))
if err != nil {
slog.Warn("net.Listen failed", slog.String("ip", ip.String()), slog.Any("error", err))
continue

if !*flagDisableTCP {
portStr := fmt.Sprint(*flagPort)
anySuccess := false
for _, ip := range st.TailscaleIPs {
ln, err := net.Listen("tcp", net.JoinHostPort(ip.String(), portStr))
if err != nil {
slog.Warn("net.Listen failed", slog.String("ip", ip.String()), slog.Any("error", err))
continue
}
anySuccess = true
ln = tls.NewListener(ln, &tls.Config{
GetCertificate: lc.GetCertificate,
})
lns = append(lns, ln)
}
if !anySuccess {
slog.Error("failed to listen on any ip", slog.Any("ips", st.TailscaleIPs))
os.Exit(1)
}
anySuccess = true
ln = tls.NewListener(ln, &tls.Config{
GetCertificate: lc.GetCertificate,
})
lns = append(lns, ln)
}
if !anySuccess {
slog.Error("failed to listen on any ip", slog.Any("ips", st.TailscaleIPs))
os.Exit(1)
}

// tailscaled needs to be setting an HTTP header for funneled requests
// that older versions don't provide.
// TODO(naman): is this the correct check?
if *flagFunnel && !version.AtLeast(st.Version, "1.71.0") {
slog.Error("Local tailscaled not new enough to support -funnel. Update Tailscale or use tsnet mode.")
os.Exit(1)
}
cleanup, watcherChan, err = server.ServeOnLocalTailscaled(ctx, lc, st, uint16(*flagPort), *flagFunnel)
if err != nil {
slog.Error("could not serve on local tailscaled", slog.Any("error", err))
os.Exit(1)
// tailscaled needs to be setting an HTTP header for funneled requests
// that older versions don't provide.
// TODO(naman): is this the correct check?
if *flagFunnel && !version.AtLeast(st.Version, "1.71.0") {
slog.Error("Local tailscaled not new enough to support -funnel. Update Tailscale or use tsnet mode.")
os.Exit(1)
}
cleanup, watcherChan, err = server.ServeOnLocalTailscaled(ctx, lc, st, uint16(*flagPort), *flagFunnel)
if err != nil {
slog.Error("could not serve on local tailscaled", slog.Any("error", err))
os.Exit(1)
}
defer cleanup()
}
defer cleanup()
} else {
hostinfo.SetApp("tsidp")
ts := &tsnet.Server{
Expand All @@ -146,23 +153,26 @@ func main() {
slog.Error("failed to get local client", slog.Any("error", err))
os.Exit(1)
}
var ln net.Listener
if *flagFunnel {
if err := ipn.CheckFunnelAccess(uint16(*flagPort), st.Self); err != nil {
slog.Error("funnel access denied", slog.Any("error", err))

if !*flagDisableTCP {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As it stands, I don't really see much of a use-case for --unix-socket without --use-local-tailscaled. The nginx server will never be able to listen on the same domain as tsnet and there will have to be a local tailscaled session for ingress through nginx to be tagged correctly.

I added the disable flag here mostly for completeness.

var ln net.Listener
if *flagFunnel {
if err := ipn.CheckFunnelAccess(uint16(*flagPort), st.Self); err != nil {
slog.Error("funnel access denied", slog.Any("error", err))
os.Exit(1)
}
ln, err = ts.ListenFunnel("tcp", fmt.Sprintf(":%d", *flagPort))
} else {
ln, err = ts.ListenTLS("tcp", fmt.Sprintf(":%d", *flagPort))
}

if err != nil {
slog.Error("failed to listen", slog.Any("error", err))
os.Exit(1)
}
ln, err = ts.ListenFunnel("tcp", fmt.Sprintf(":%d", *flagPort))
} else {
ln, err = ts.ListenTLS("tcp", fmt.Sprintf(":%d", *flagPort))
}

if err != nil {
slog.Error("failed to listen", slog.Any("error", err))
os.Exit(1)
lns = append(lns, ln)
}

lns = append(lns, ln)
}

srv := server.New(
Expand All @@ -173,7 +183,11 @@ func main() {
*flagEnableSTS,
)

srv.SetServerURL(strings.TrimSuffix(st.Self.DNSName, "."), *flagPort)
if *flagServerURL != "" {
srv.SetServerURL(*flagServerURL, *flagPort)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be smarter to provide a list of additional TrustedOrigins, that way all the listeners can be supported simultaneously. I'm not very familiar with OAuth2 though, so I didn't want to touch the server.go file in this PR.

} else {
srv.SetServerURL(strings.TrimSuffix(st.Self.DNSName, "."), *flagPort)
}

// Load funnel clients from disk if they exist, regardless of whether funnel is enabled
// This ensures OIDC clients persist across restarts
Expand All @@ -184,6 +198,48 @@ func main() {

slog.Info("tsidp server started", slog.String("server_url", srv.ServerURL()))

if fdStr := os.Getenv("LISTEN_FDS"); fdStr != "" {
fds, err := strconv.Atoi(fdStr)
if err != nil {
slog.Error("failed to listen on systemd socket", slog.Any("error", err))
}
// systemd socket activation starts at fd 3
for fd := 3; fd < 3+fds; fd++ {
file := os.NewFile(uintptr(fd), "systemd-socket")
ln, err := net.FileListener(file)
if err != nil {
slog.Error("failed to listen on systemd socket", slog.Any("error", err))
}
lns = append(lns, ln)
}
} else if *flagUnixSocket != "" {
socketPath := *flagUnixSocket
info, err := os.Stat(socketPath)
if err == nil && (info.Mode()&os.ModeSocket) != 0 {
// A socket file already exists.
c, err := net.Dial("unix", socketPath)
if err == nil {
c.Close()
slog.Error("unix socket already in use")
os.Exit(1)
}

// It's a stale socket, so we can remove it.
os.Remove(socketPath)
}

ln, err := net.Listen("unix", *flagUnixSocket)
if err != nil {
slog.Error("failed to listen on unix socket", slog.Any("error", err))
os.Exit(1)
}
defer func() {
ln.Close() // TODO: the other listeners are not closed?
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also found this somewhat strange, I'm fairly sure that unix sockets must be closed, but I noticed that none of the tcp listeners are being closed. Are they cleaned up automatically by go when the program exits?

os.Remove(*flagUnixSocket)
}()
lns = append(lns, ln)
}

if *flagLocalPort != -1 {
loopbackURL := fmt.Sprintf("http://localhost:%d", *flagLocalPort)
slog.Info("Also running tsidp at loopback", slog.String("loopback_url", loopbackURL))
Expand Down