Skip to content

Commit

Permalink
ipn/ipnlocal: do unexpired cert renewals in the background
Browse files Browse the repository at this point in the history
We were eagerly doing a synchronous renewal of the cert while
trying to serve traffic. Instead of that, just do the cert
renewal in the background and continue serving traffic as long
as the cert is still valid.

This regressed in c1ecae1 when
we introduced ARI support and were trying to make the experience
of `tailscale cert` better. However, that ended up regressing
the experience for tsnet as it would not always doing the renewal
synchronously.

Fixes tailscale#9783

Signed-off-by: Maisem Ali <maisem@tailscale.com>
  • Loading branch information
Maisem Ali authored and maisem committed Oct 12, 2023
1 parent 1a78f24 commit 24f322b
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 18 deletions.
27 changes: 13 additions & 14 deletions ipn/ipnlocal/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,9 @@ var acmeDebug = envknob.RegisterBool("TS_DEBUG_ACME")
// ACME process. ACME process is used for new domain certs, existing expired
// certs or existing certs that should get renewed due to upcoming expiry.
//
// syncRenewal changes renewal behavior for existing certs that are still valid
// but need renewal. When syncRenewal is set, the method blocks until a new
// cert is issued. When syncRenewal is not set, existing cert is returned right
// away and renewal is kicked off in a background goroutine.
func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) {
// If a cert is expired, it will be renewed synchronously otherwise it will be
// renewed asynchronously.
func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) {
if !validLookingCertDomain(domain) {
return nil, errors.New("invalid domain")
}
Expand All @@ -108,18 +106,16 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewa
}

if pair, err := getCertPEMCached(cs, domain, now); err == nil {
shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair)
if err != nil {
// If we got here, we have a valid unexpired cert.
// Check whether we should start an async renewal.
if shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair); err != nil {
logf("error checking for certificate renewal: %v", err)
} else if !shouldRenew {
return pair, nil
}
if !syncRenewal {
} else if shouldRenew {
logf("starting async renewal")
// Start renewal in the background.
go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now)
}
// Synchronous renewal happens below.
return pair, nil
}

pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now)
Expand All @@ -130,6 +126,8 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewa
return pair, nil
}

// shouldStartDomainRenewal reports whether the domain's cert should be renewed
// based on the current time, the cert's expiry, and the ARI check.
func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, now time.Time, pair *TLSCertKeyPair) (bool, error) {
renewMu.Lock()
defer renewMu.Unlock()
Expand Down Expand Up @@ -365,8 +363,9 @@ type TLSCertKeyPair struct {
func keyFile(dir, domain string) string { return filepath.Join(dir, domain+".key") }
func certFile(dir, domain string) string { return filepath.Join(dir, domain+".crt") }

// getCertPEMCached returns a non-nil keyPair and true if a cached keypair for
// domain exists on disk in dir that is valid at the provided now time.
// getCertPEMCached returns a non-nil keyPair if a cached keypair for domain
// exists on disk in dir that is valid at the provided now time.
//
// If the keypair is expired, it returns errCertExpired.
// If the keypair doesn't exist, it returns ipn.ErrStateNotExist.
func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKeyPair, err error) {
Expand Down
2 changes: 1 addition & 1 deletion ipn/ipnlocal/cert_js.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ type TLSCertKeyPair struct {
CertPEM, KeyPEM []byte
}

func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) {
func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) {
return nil, errors.New("not implemented for js/wasm")
}
4 changes: 2 additions & 2 deletions ipn/ipnlocal/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort)
GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
pair, err := b.GetCertPEM(ctx, sni, false)
pair, err := b.GetCertPEM(ctx, sni)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -757,7 +757,7 @@ func (b *LocalBackend) getTLSServeCertForPort(port uint16) func(hi *tls.ClientHe

ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
pair, err := b.GetCertPEM(ctx, hi.ServerName, false)
pair, err := b.GetCertPEM(ctx, hi.ServerName)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion ipn/localapi/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) {
http.Error(w, "internal handler config wired wrong", 500)
return
}
pair, err := h.b.GetCertPEM(r.Context(), domain, true)
pair, err := h.b.GetCertPEM(r.Context(), domain)
if err != nil {
// TODO(bradfitz): 500 is a little lazy here. The errors returned from
// GetCertPEM (and everywhere) should carry info info to get whether
Expand Down

0 comments on commit 24f322b

Please sign in to comment.