|  | 
| 19 | 19 | import static com.google.common.base.Preconditions.checkNotNull; | 
| 20 | 20 | import static com.google.common.base.Preconditions.checkState; | 
| 21 | 21 | 
 | 
|  | 22 | +import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; | 
|  | 23 | +import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; | 
| 22 | 24 | import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; | 
| 23 | 25 | import io.netty.handler.ssl.SslContext; | 
| 24 | 26 | 
 | 
| 25 | 27 | /** | 
| 26 |  | - * Enables the CDS policy to initialize this object with the received {@link UpstreamTlsContext} & | 
| 27 |  | - * communicate it to the consumer i.e. {@link SdsProtocolNegotiators.ClientSdsProtocolNegotiator} | 
|  | 28 | + * Enables Client or server side to initialize this object with the received {@link BaseTlsContext} | 
|  | 29 | + * and communicate it to the consumer i.e. {@link SdsProtocolNegotiators} | 
| 28 | 30 |  * to lazily evaluate the {@link SslContextProvider}. The supplier prevents credentials leakage in | 
| 29 |  | - * cases where the user is not using xDS credentials but the CDS policy contains a non-default | 
| 30 |  | - * {@link UpstreamTlsContext}. | 
|  | 31 | + * cases where the user is not using xDS credentials but the client/server contains a non-default | 
|  | 32 | + * {@link BaseTlsContext}. | 
| 31 | 33 |  */ | 
| 32 | 34 | public final class SslContextProviderSupplier implements Closeable { | 
| 33 | 35 | 
 | 
| 34 |  | -  private final UpstreamTlsContext upstreamTlsContext; | 
|  | 36 | +  private final BaseTlsContext tlsContext; | 
| 35 | 37 |   private final TlsContextManager tlsContextManager; | 
| 36 | 38 |   private SslContextProvider sslContextProvider; | 
| 37 | 39 |   private boolean shutdown; | 
| 38 | 40 | 
 | 
| 39 | 41 |   public SslContextProviderSupplier( | 
| 40 |  | -      UpstreamTlsContext upstreamTlsContext, TlsContextManager tlsContextManager) { | 
| 41 |  | -    this.upstreamTlsContext = upstreamTlsContext; | 
|  | 42 | +      BaseTlsContext tlsContext, TlsContextManager tlsContextManager) { | 
|  | 43 | +    this.tlsContext = tlsContext; | 
| 42 | 44 |     this.tlsContextManager = tlsContextManager; | 
| 43 | 45 |   } | 
| 44 | 46 | 
 | 
| 45 |  | -  public UpstreamTlsContext getUpstreamTlsContext() { | 
| 46 |  | -    return upstreamTlsContext; | 
|  | 47 | +  public BaseTlsContext getTlsContext() { | 
|  | 48 | +    return tlsContext; | 
| 47 | 49 |   } | 
| 48 | 50 | 
 | 
| 49 | 51 |   /** Updates SslContext via the passed callback. */ | 
| 50 | 52 |   public synchronized void updateSslContext(final SslContextProvider.Callback callback) { | 
| 51 | 53 |     checkNotNull(callback, "callback"); | 
| 52 | 54 |     checkState(!shutdown, "Supplier is shutdown!"); | 
| 53 | 55 |     if (sslContextProvider == null) { | 
| 54 |  | -      sslContextProvider = | 
| 55 |  | -          tlsContextManager.findOrCreateClientSslContextProvider(upstreamTlsContext); | 
|  | 56 | +      sslContextProvider = getSslContextProvider(); | 
| 56 | 57 |     } | 
| 57 | 58 |     // we want to increment the ref-count so call findOrCreate again... | 
| 58 |  | -    final SslContextProvider toRelease = | 
| 59 |  | -        tlsContextManager.findOrCreateClientSslContextProvider(upstreamTlsContext); | 
|  | 59 | +    final SslContextProvider toRelease = getSslContextProvider(); | 
| 60 | 60 |     sslContextProvider.addCallback( | 
| 61 | 61 |         new SslContextProvider.Callback(callback.getExecutor()) { | 
| 62 | 62 | 
 | 
| 63 | 63 |           @Override | 
| 64 | 64 |           public void updateSecret(SslContext sslContext) { | 
| 65 | 65 |             callback.updateSecret(sslContext); | 
| 66 |  | -            tlsContextManager.releaseClientSslContextProvider(toRelease); | 
|  | 66 | +            releaseSslContextProvider(toRelease); | 
| 67 | 67 |           } | 
| 68 | 68 | 
 | 
| 69 | 69 |           @Override | 
| 70 | 70 |           public void onException(Throwable throwable) { | 
| 71 | 71 |             callback.onException(throwable); | 
| 72 |  | -            tlsContextManager.releaseClientSslContextProvider(toRelease); | 
|  | 72 | +            releaseSslContextProvider(toRelease); | 
| 73 | 73 |           } | 
| 74 | 74 |         }); | 
| 75 | 75 |   } | 
| 76 | 76 | 
 | 
| 77 |  | -  /** Called by {@link io.grpc.xds.CdsLoadBalancer} when upstreamTlsContext changes. */ | 
|  | 77 | +  private void releaseSslContextProvider(SslContextProvider toRelease) { | 
|  | 78 | +    if (tlsContext instanceof UpstreamTlsContext) { | 
|  | 79 | +      tlsContextManager.releaseClientSslContextProvider(toRelease); | 
|  | 80 | +    } else { | 
|  | 81 | +      tlsContextManager.releaseServerSslContextProvider(toRelease); | 
|  | 82 | +    } | 
|  | 83 | +  } | 
|  | 84 | + | 
|  | 85 | +  private SslContextProvider getSslContextProvider() { | 
|  | 86 | +    return tlsContext instanceof UpstreamTlsContext | 
|  | 87 | +        ? tlsContextManager.findOrCreateClientSslContextProvider((UpstreamTlsContext) tlsContext) | 
|  | 88 | +        : tlsContextManager.findOrCreateServerSslContextProvider((DownstreamTlsContext) tlsContext); | 
|  | 89 | +  } | 
|  | 90 | + | 
|  | 91 | +  /** Called by consumer when tlsContext changes. */ | 
| 78 | 92 |   @Override | 
| 79 | 93 |   public synchronized void close() { | 
| 80 |  | -    if (sslContextProvider != null) { | 
|  | 94 | +    if (tlsContext instanceof UpstreamTlsContext) { | 
| 81 | 95 |       tlsContextManager.releaseClientSslContextProvider(sslContextProvider); | 
|  | 96 | +    } else { | 
|  | 97 | +      tlsContextManager.releaseServerSslContextProvider(sslContextProvider); | 
| 82 | 98 |     } | 
| 83 | 99 |     shutdown = true; | 
| 84 | 100 |   } | 
|  | 
0 commit comments