Open
Description
TF: 2.0.0
TFP version: 0.8.0
Hi,
I want to write a VAE with Tensorflow-Probability. If I use tfpl.IndependentNormal at the end of the decoder, I get checkerboard artifacts. If I use instead tfd.Independent(tfd.Normal(...)) it works fine.
To show you what I mean, you can find the code here:
import tensorflow as tf
from tensorflow.keras import layers as tfl
import numpy as np
from tensorflow_probability import layers as tfpl
from tensorflow_probability import distributions as tfd
import matplotlib.pyplot as plt
# basic model
decoder = tf.keras.models.Sequential()
decoder.add(tfl.InputLayer(input_shape=[10]))
decoder.add(tfl.Reshape([1, 1, 10]))
decoder.add(tfl.UpSampling2D((2,2)))
decoder.add(tfl.Conv2D(100,(3,3),activation='selu',padding='same'))
decoder.add(tfl.UpSampling2D((2,2)))
decoder.add(tfl.Conv2D(100,(3,3),activation='selu',padding='same'))
decoder.add(tfl.UpSampling2D((2,2)))
decoder.add(tfl.Conv2D(100,(3,3),activation='selu',padding='same'))
decoder.add(tfl.UpSampling2D((2,2)))
decoder.add(tfl.Conv2D(100,(3,3),activation='selu',padding='same'))
decoder.add(tfl.UpSampling2D((2,2)))
plt.figure(figsize=(17,17))
# test input
input_values = np.array(np.random.random((1,10)),dtype=np.float32)
# 1. version: Pure TF.Conv-Layer
decoder1 = tf.keras.models.Sequential(decoder)
decoder1.add(tfl.Conv2D(1,(3,3),activation='selu',padding='same'))
plt.subplot(1,4,1)
plt.imshow(decoder1(input_values)[0,:,:,0])
plt.title('Pure TF.Conv-Layer')
# 2. version: Using tfpl.IndependentNormal
decoder2 = tf.keras.models.Sequential(decoder)
decoder2.add(tfl.Conv2D(2,(3,3),padding='same'))
decoder2.add(tfl.Flatten())
decoder2.add(tfpl.IndependentNormal((32,32,1)))
plt.subplot(1,4,2)
plt.imshow(decoder2(input_values).mean()[0,:,:,0])
plt.title('tfpl.IndependentNormal')
# 3. version: Using tfd.Independent(tfd.Normal(...))
plt.subplot(1,4,3)
plt.imshow(tfd.Independent(tfd.Normal(decoder1(input_values),decoder1(input_values)), 2).mean()[0,:,:,0])
plt.title('tfd.Independent(tfd.Normal(...))')
# 4. version: Using tfd.Independent(tfd.Normal(...)) in tfpl.DistributionLambda
def IndependentConvNormal():
return tfpl.DistributionLambda(
make_distribution_fn=lambda t:
tfd.Independent(
tfd.Normal(
loc=t[...,:1],
scale=tf.exp(t[...,1:]))))
decoder3 = tf.keras.models.Sequential(decoder)
decoder3.add(tfl.Conv2D(2,(3,3),padding='same'))
decoder3.add(IndependentConvNormal())
plt.subplot(1,4,4)
plt.imshow(decoder3(input_values).mean()[0,:,:,0])
plt.title('tfd.Independent(tfd.Normal(...))\nin tfpl.DistributionLambda')
plt.show()
Thanks for your help! :)
Activity