Gradient Computing Error for DPMM Example #604

Open · shashankg7 opened this issue Oct 11, 2019 · 8 comments
Labels: keras, tensorflow 2.0 (Issues related to TF 2.0.)

@shashankg7 commented Oct 11, 2019

I am trying to run the example code from TensorFlow Probability for the Dirichlet Process Mixture Model (https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Fitting_DPMM_Using_pSGLD.ipynb). I am rewriting it in TF 2.0, following the tutorial at https://brendanhasz.github.io/2019/06/12/tfp-gmm.html.

The model class is defined as:

# Setup assumed from the notebook (dtype and num_samples are defined there), e.g.:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors
dtype = np.float64     # notebook setup
num_samples = 50000    # total observation count used in the notebook

class GaussianMixtureModelDP(tf.keras.Model):
    """A Bayesian Gaussian mixture model.

    Assumes Gaussians' variances in each dimension are independent.

    Parameters
    ----------
    max_cluster_num : int > 0
        Maximum number of mixture components.
    dims : int > 0
        Number of dimensions.
    batch_size : int > 0
        Mini-batch size used to scale the data log likelihood.
    """
  
    def __init__(self, max_cluster_num, dims, batch_size):

        # Initialize
        super(GaussianMixtureModelDP, self).__init__()
        self.max_cluster_num = max_cluster_num
        self.dims = dims
        self.batch_size = batch_size

        # Mixture weights; the constraint is applied after optimizer updates
        self.mix_probs = tf.Variable(
            initial_value=np.ones([max_cluster_num], dtype) / max_cluster_num,
            constraint=tf.nn.softmax)
        # self.mix_probs = tf.nn.softmax(self.mix_probs)

        # Component locations
        self.loc = tf.Variable(
            initial_value=np.random.uniform(
                low=-9,   # set around the minimum sample value
                high=9,   # set around the maximum sample value
                size=[max_cluster_num, dims]))

        # Component scales; the softplus constraint keeps them positive after updates
        self.precision = tf.Variable(
            initial_value=np.ones([max_cluster_num, dims], dtype=dtype),
            constraint=tf.nn.softplus)
        # self.precision = tf.nn.softplus(self.precision)

        # Dirichlet process concentration; softplus constraint keeps it positive
        self.alpha = tf.Variable(
            initial_value=np.ones([1], dtype=dtype),
            constraint=tf.nn.softplus)
        # self.alpha = tf.nn.softplus(self.alpha)
        # self.training_vals = [self.mix_probs, self.alpha, self.loc, self.precision]
    
    
    def call(self, x, sampling=True):
        """Compute the joint log probability for a batch of data.

        Parameters
        ----------
        x : tf.Tensor
            A batch of data.
        sampling : bool
            Unused in this implementation.

        Returns
        -------
        joint_log_probs : tf.Tensor
            The prior log probabilities of the parameters (each divided by
            num_samples) plus the data log likelihood (divided by batch_size).
        """

        # Prior distributions over the parameters
        rv_symmetric_dirichlet_process = tfd.Dirichlet(
            concentration=np.ones(self.max_cluster_num, dtype)
            * self.alpha / self.max_cluster_num,
            name='rv_sdp')

        rv_loc = tfd.Independent(
            tfd.Normal(
                loc=tf.zeros([self.max_cluster_num, self.dims], dtype=dtype),
                scale=tf.ones([self.max_cluster_num, self.dims], dtype=dtype)),
            reinterpreted_batch_ndims=1,
            name='rv_loc')

        rv_precision = tfd.Independent(
            tfd.InverseGamma(
                concentration=np.ones([self.max_cluster_num, self.dims], dtype),
                scale=np.ones([self.max_cluster_num, self.dims], dtype)),
            reinterpreted_batch_ndims=1,
            name='rv_precision')

        rv_alpha = tfd.InverseGamma(
            concentration=np.ones([1], dtype=dtype),
            scale=np.ones([1]),
            name='rv_alpha')

        # Likelihood: mixture model over the current batch
        rv_observations = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(probs=self.mix_probs),
            components_distribution=tfd.MultivariateNormalDiag(
                loc=self.loc,
                scale_diag=self.precision))

        log_prob_parts = [
            rv_loc.log_prob(self.loc) / num_samples,
            rv_precision.log_prob(self.precision) / num_samples,
            rv_alpha.log_prob(self.alpha) / num_samples,
            rv_symmetric_dirichlet_process.log_prob(self.mix_probs)[..., tf.newaxis]
            / num_samples,
            rv_observations.log_prob(x) / self.batch_size
        ]
        joint_log_probs = tf.reduce_sum(tf.concat(log_prob_parts, axis=-1), axis=-1)

        return joint_log_probs
  


The code for training is:


# Learning rates and decay
starter_learning_rate = 1e-6
end_learning_rate = 1e-10
decay_steps = 1e4

# Number of training steps
training_steps = 10000

# Mini-batch size
batch_size = 20

# Sample size for parameter posteriors
sample_size = 100

model = GaussianMixtureModelDP(30, 2, batch_size)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

batch_size = 500
dataset = tf.data.Dataset.from_tensor_slices(
    (observations)).shuffle(10000).batch(batch_size)

@tf.function
def train_step(data):
    with tf.GradientTape() as tape:
        log_likelihoods = model(data)
    print(log_likelihoods)
    tvars = model.trainable_variables
    gradients = tape.gradient(-log_likelihoods, tvars)
    print(gradients)
    optimizer.apply_gradients(zip(gradients, tvars))

# Fit the model
EPOCHS = 1000
for epoch in range(EPOCHS):
    for data in dataset:
        print(data.shape)
        train_step(data)

When I run the code, I get the following error:

ValueError: in converted code:

<ipython-input-59-f1508ca04ee6>:9 train_step  *
    optimizer.apply_gradients(zip(gradients, tvars))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:427 apply_gradients
    grads_and_vars = _filter_grads(grads_and_vars)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:975 _filter_grads
    ([v.name for _, v in grads_and_vars],))

ValueError: No gradients provided for any variable: ['Variable:0', 'Variable:0', 'Variable:0', 'Variable:0'].
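
For reference, apply_gradients raises this when every gradient returned by tape.gradient is None, i.e. the tape found no path from the loss back to the listed variables. A minimal, unrelated sketch that reproduces the same message, just to illustrate the failure mode rather than the cause here:

import tensorflow as tf

v = tf.Variable(1.0)
opt = tf.keras.optimizers.Adam(learning_rate=1e-3)

with tf.GradientTape() as tape:
    # Materializing the value detaches it from the tape, so there is no path back to v.
    loss = tf.constant(v.numpy()) ** 2

grads = tape.gradient(loss, [v])       # [None]
opt.apply_gradients(zip(grads, [v]))   # ValueError: No gradients provided for any variable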
@brianwa84 (Contributor) commented Oct 11, 2019 via email

@brianwa84 (Contributor) commented Oct 11, 2019 via email

@shashankg7 (Author)

Thanks for your reply.

I am new to the TensorFlow and TFP world, so please excuse my lack of knowledge.

I tried the same model with tf.Module instead of tf.keras.Model and removed the constraints to check the end-to-end flow, but I am still getting the same error, "No gradients provided".
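
For context, as far as I can tell the constraint argument of tf.Variable is only a projection that the Keras optimizers apply after each update; the forward pass still reads the raw variable value, and gradients still flow to it. A small illustration, with a toy variable and clip function just for the example:

import tensorflow as tf

# `constraint` is applied after the optimizer update, not when the variable is read.
v = tf.Variable([2.0, -1.0], constraint=lambda t: tf.clip_by_value(t, 0.0, 1.0))
opt = tf.keras.optimizers.SGD(learning_rate=0.1)

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(v ** 2)       # the forward pass sees the raw values [2, -1]

grads = tape.gradient(loss, [v])       # gradients flow to v as usual
opt.apply_gradients(zip(grads, [v]))   # update, then project into [0, 1]
print(v.numpy())                       # [1. 0.]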

@shashankg7 (Author)

Update:

I implemented the following code:

class GaussianMixtureModelDP(tf.Module):
    """A Bayesian Gaussian mixture model.

    Assumes Gaussians' variances in each dimension are independent.

    Parameters
    ----------
    max_cluster_num : int > 0
        Maximum number of mixture components.
    dims : int > 0
        Number of dimensions.
    batch_size : int > 0
        Mini-batch size used to scale the data log likelihood.
    """
  
    def __init__(self, max_cluster_num, dims, batch_size):

        # Initialize
        super(GaussianMixtureModelDP, self).__init__()
        self.max_cluster_num = max_cluster_num
        self.dims = dims
        self.batch_size = batch_size

        # Unconstrained variables; the constraints are now applied in call()
        # through bijectors instead of the `constraint` argument
        self.mix_probs = tf.Variable(
            initial_value=np.ones([max_cluster_num], dtype) / max_cluster_num)

        self.loc = tf.Variable(
            initial_value=np.random.uniform(
                low=-9,   # set around the minimum sample value
                high=9,   # set around the maximum sample value
                size=[max_cluster_num, dims]))

        self.precision = tf.Variable(
            initial_value=np.ones([max_cluster_num, dims], dtype=dtype))

        self.alpha = tf.Variable(
            initial_value=np.ones([1], dtype=dtype))

        self.training_vals = [self.mix_probs, self.alpha, self.loc, self.precision]

    
    
    def call(self, x, sampling=True):
        """Compute the joint log probability for a batch of data.

        Parameters
        ----------
        x : tf.Tensor
            A batch of data.
        sampling : bool
            Unused in this implementation.

        Returns
        -------
        joint_log_probs : tf.Tensor
            The prior log probabilities of the parameters (each divided by
            num_samples) plus the data log likelihood (divided by batch_size).
        """

        # Prior distributions over the parameters
        rv_symmetric_dirichlet_process = tfd.Dirichlet(
            concentration=np.ones(self.max_cluster_num, dtype)
            * tfp.util.TransformedVariable(self.alpha, tfb.Softplus())
            / self.max_cluster_num,
            name='rv_sdp')

        rv_loc = tfd.Independent(
            tfd.Normal(
                loc=tf.zeros([self.max_cluster_num, self.dims], dtype=dtype),
                scale=tf.ones([self.max_cluster_num, self.dims], dtype=dtype)),
            reinterpreted_batch_ndims=1,
            name='rv_loc')

        rv_precision = tfd.Independent(
            tfd.InverseGamma(
                concentration=np.ones([self.max_cluster_num, self.dims], dtype),
                scale=np.ones([self.max_cluster_num, self.dims], dtype)),
            reinterpreted_batch_ndims=1,
            name='rv_precision')

        rv_alpha = tfd.InverseGamma(
            concentration=np.ones([1], dtype=dtype),
            scale=np.ones([1]),
            name='rv_alpha')

        # Likelihood: mixture model over the current batch, with the
        # constraints applied through TransformedVariable
        rv_observations = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(
                probs=tfp.util.TransformedVariable(self.mix_probs, tfb.SoftmaxCentered())),
            components_distribution=tfd.MultivariateNormalDiag(
                loc=self.loc,
                scale_diag=tfp.util.TransformedVariable(self.precision, tfb.Softplus())))

        log_prob_parts = [
            rv_loc.log_prob(self.loc) / num_samples,
            rv_precision.log_prob(
                tfp.util.TransformedVariable(self.precision, tfb.Softplus())) / num_samples,
            rv_alpha.log_prob(
                tfp.util.TransformedVariable(self.alpha, tfb.Softplus())) / num_samples,
            rv_symmetric_dirichlet_process.log_prob(
                tfp.util.TransformedVariable(self.mix_probs, tfb.SoftmaxCentered()))[..., tf.newaxis]
            / num_samples,
            rv_observations.log_prob(x) / self.batch_size
        ]
        joint_log_probs = tf.reduce_sum(tf.concat(log_prob_parts, axis=-1), axis=-1)

        return joint_log_probs

I am getting the same error.
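
One note on the snippet above: as far as I understand tfp.util.TransformedVariable, it creates its own underlying (pretransformed) tf.Variable and only uses the value passed in for initialization, so constructing one inside call() from self.precision etc. gives a fresh variable on each call and no gradient path back to the variables in training_vals. A sketch of the usual pattern (constructed once, trained through its own trainable variables; names and shapes here just mirror the code above):

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors
dtype = np.float64
max_cluster_num, dims = 30, 2

# Constrained parameters built once; the optimizer updates their underlying
# unconstrained variables.
mix_probs = tfp.util.TransformedVariable(
    np.ones([max_cluster_num], dtype) / max_cluster_num,
    bijector=tfb.SoftmaxCentered(), name='mix_probs')
precision = tfp.util.TransformedVariable(
    np.ones([max_cluster_num, dims], dtype), bijector=tfb.Softplus(), name='precision')
alpha = tfp.util.TransformedVariable(
    np.ones([1], dtype), bijector=tfb.Softplus(), name='alpha')

# These (plus loc) are what would be handed to the optimizer.
trainable = (list(mix_probs.trainable_variables)
             + list(precision.trainable_variables)
             + list(alpha.trainable_variables))

with tf.GradientTape() as tape:
    loss = -tf.reduce_sum(tf.math.log(tf.convert_to_tensor(mix_probs)))
print([g is not None for g in tape.gradient(loss, mix_probs.trainable_variables)])
# [True]: the gradient reaches the underlying unconstrained variable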

@shashankg7 (Author)

I am trying it with the Adam optimizer:

batch_size = 500
dataset = tf.data.Dataset.from_tensor_slices(
    (observations)).shuffle(10000).batch(batch_size)

@tf.function
def train_step(data):
    with tf.GradientTape() as tape:
        log_likelihoods = model.call(data)
    print(log_likelihoods)
    tvars = model.training_vals
    gradients = tape.gradient(-log_likelihoods, tvars)
    print(gradients)
    optimizer.apply_gradients(zip(gradients, tvars))

# Fit the model
EPOCHS = 1000
for epoch in range(EPOCHS):
    for data in dataset:
        print(data.shape)
        train_step(data)
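
A small debugging variant of train_step (run eagerly, without @tf.function, so the prints are visible; it assumes the same model, optimizer, and dataset objects as above) that reports which of the tracked variables actually receive a gradient:

def train_step_debug(data):
    with tf.GradientTape() as tape:
        log_likelihoods = model.call(data)
    tvars = model.training_vals
    gradients = tape.gradient(-log_likelihoods, tvars)
    for grad, var in zip(gradients, tvars):
        print(var.name, '-> no gradient' if grad is None else '-> ok')
    # Apply only the gradients that exist, to avoid the ValueError while debugging.
    grads_and_vars = [(g, v) for g, v in zip(gradients, tvars) if g is not None]
    if grads_and_vars:
        optimizer.apply_gradients(grads_and_vars)

for data in dataset.take(1):
    train_step_debug(data)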

@shashankg7 (Author)

@brianwa84 Hey Brian, sorry to bug you again, but I am struggling to make this run. I have tried running it with different versions, etc., and have not been able to make it work. The code snippet above looks right to me, but the gradients are not being computed. It would be a great help if you could point out the mistake.

@AngelBerihuete commented Nov 6, 2019

Hi @shashankg7
Any progress with this issue?

@shashankg7 (Author)

Hi @AngelBerihuete. No progress. Working with TFP is very frustrating; I am planning to move to Pyro.

@dynamicwebpaige added the keras and tensorflow 2.0 (Issues related to TF 2.0.) labels on Nov 9, 2019