Description
Hi,
When I try to get predictions from a Lattice model on more than one batch of data at once, an error is raised. Predicting on several stacked batches in one call is a convenient way to get predictions efficiently, and plain Keras neural-network models support it; see some examples in this colab.
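For comparison, a Keras `Dense` layer handles extra leading axes because it contracts its kernel with the last axis of the input only. The NumPy sketch below mimics that behaviour with an assumed 2-feature, 1-unit layer; `kernel`, `bias` and `dense_like` are illustrative names, and this is an analogy rather than Keras's actual implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical dense-layer parameters: 2 input features -> 1 output unit.
kernel = rng.standard_normal((2, 1)).astype(np.float32)
bias = np.zeros(1, dtype=np.float32)

def dense_like(x):
    # `@` contracts the last axis of x with axis 0 of kernel and broadcasts
    # over every leading axis, so (..., 2) -> (..., 1).
    return x @ kernel + bias

single_batch = rng.standard_normal((100, 2)).astype(np.float32)
stacked_batches = rng.standard_normal((2, 100, 2)).astype(np.float32)

print(dense_like(single_batch).shape)     # (100, 1)
print(dense_like(stacked_batches).shape)  # (2, 100, 1)
```

Each slice of the stacked output equals the output for that batch alone, which is the behaviour one would like from the lattice model too.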
As far as I can tell from the API docs and the source code, this seems to be related to the input shapes accepted by the PWL calibration (`PWLCalibration`) layers, but I wonder if there is an easy workaround.
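If the calibration layers only accept rank-2 inputs, one straightforward (if slower) workaround is to call the model once per batch and stack the results. A minimal sketch, assuming the model maps `(batch_size, n_features)` to `(batch_size, n_outputs)`; `predict_per_batch` and `toy_model` are hypothetical names, and `toy_model` only stands in for the lattice model so the snippet runs without TensorFlow Lattice.

```python
import numpy as np

def predict_per_batch(model_fn, batched_inputs):
    """Call model_fn on each (batch_size, n_features) slice and stack outputs.

    Assumes model_fn maps (batch_size, n_features) -> (batch_size, n_outputs).
    """
    outputs = [np.asarray(model_fn(batch)) for batch in batched_inputs]
    return np.stack(outputs, axis=0)  # (n_batches, batch_size, n_outputs)

# Hypothetical stand-in for LatticeModel, mimicking the rank-2 restriction
# that the calibration layers appear to have.
def toy_model(x):
    if x.ndim != 2:
        raise ValueError(f"expected rank-2 input, got shape {x.shape}")
    return x.sum(axis=-1, keepdims=True)

batched_inputs = np.random.randn(2, 100, 2).astype(np.float32)
print(predict_per_batch(toy_model, batched_inputs).shape)  # (2, 100, 1)
```

With the real model, `predict_per_batch(model, batched_inputs)` should behave the same way, since a Keras model accepts NumPy arrays and returns eager tensors that `np.asarray` can convert. The cost is one model call per batch.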
In particular, the following code captures what I would like to do (and raises an error when called on `batched_inputs`):
```python
import numpy as np
import tensorflow as tf
import tensorflow_lattice as tfl


class LatticeModel(tf.keras.Model):
    def __init__(self, nodes=(2, 2), nkeypoints=100):
        super().__init__()
        # One 1-D PWL calibrator per input feature, applied in parallel.
        self.combined_calibrators = tfl.layers.ParallelCombination()
        for size in nodes:
            calibration_layer = tfl.layers.PWLCalibration(
                input_keypoints=np.linspace(0, 1, nkeypoints),
                output_min=0.0,
                # A lattice dimension of size s expects inputs in [0, s - 1].
                output_max=size - 1.0,
            )
            self.combined_calibrators.append(calibration_layer)
        self.lattice = tfl.layers.Lattice(lattice_sizes=list(nodes), interpolation="simplex")

    def call(self, x):
        rescaled = self.combined_calibrators(x)  # (batch, len(nodes))
        return self.lattice(rescaled)            # (batch, 1)


# Define some input data: 100 examples with 2 features each.
x1 = np.random.randn(100, 1).astype(np.float32)
x2 = np.random.randn(100, 1).astype(np.float32)
inputs = tf.concat([x1, x2], axis=-1)  # shape (100, 2)

# Initialize the model and feed it a single batch of size 100: this works.
model = LatticeModel()
model(inputs)

# Now we would like to predict on many batches at once (here 2), i.e. on an
# input of shape (n_batches, batch_size, n_features). This call raises an error.
batched_inputs = np.random.randn(2, 100, 2).astype(np.float32)
model(batched_inputs)
```
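A vectorized alternative that keeps a single model call is to fold the batch axis into the example axis, run the model on a rank-2 input, and restore the original leading shape afterwards. This is valid because calibration and lattice interpolation act on each example independently, so only row order matters. A sketch, assuming the `(n_batches, batch_size, n_features)` layout above; `predict_flattened` and `toy_model` are hypothetical helpers, with `toy_model` standing in for `LatticeModel`.

```python
import numpy as np

def predict_flattened(model_fn, batched_inputs):
    """Run a rank-2 model on input of shape (..., n_features) with one call.

    Assumes model_fn maps (N, n_features) -> (N, n_outputs) row by row.
    """
    batched_inputs = np.asarray(batched_inputs, dtype=np.float32)
    leading_shape = batched_inputs.shape[:-1]      # e.g. (2, 100)
    n_features = batched_inputs.shape[-1]          # e.g. 2
    flat = batched_inputs.reshape(-1, n_features)  # e.g. (200, 2)
    flat_out = np.asarray(model_fn(flat))          # e.g. (200, 1)
    return flat_out.reshape(leading_shape + flat_out.shape[-1:])

# Hypothetical stand-in for LatticeModel, mimicking the rank-2 restriction
# that the calibration layers appear to have.
def toy_model(x):
    if x.ndim != 2:
        raise ValueError(f"expected rank-2 input, got shape {x.shape}")
    return x.sum(axis=-1, keepdims=True)

batched_inputs = np.random.randn(2, 100, 2).astype(np.float32)
out = predict_flattened(toy_model, batched_inputs)
print(out.shape)  # (2, 100, 1)
```

With the model from this issue the call would be `predict_flattened(model, batched_inputs)`. Inside a `tf.function` one would use `tf.reshape` instead of NumPy so the computation stays in the graph.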
Thanks a lot!
Matías.