Kind of issue: The bottleneck with the tandem embeddings may be that the embedding converges to an optimum well before the dense layers do. Once it does, the embedding gradients zero out, and by the chain rule this cascades and zeroes out all the other gradients as well.
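One way to check that hypothesis before touching the model is to measure the gradient norm on the embedding weights for a single batch. The sketch below is a hypothetical diagnostic, not part of the proposal; model, x_batch, and y_batch stand in for an existing compiled model and one batch of its training data.

import tensorflow as tf

def embedding_gradient_norms(model, x_batch, y_batch,
                             loss_fn=tf.keras.losses.CategoricalCrossentropy()):
    # Collect the trainable weights of every Embedding layer in the model.
    embedding_vars = [v for layer in model.layers
                      if isinstance(layer, tf.keras.layers.Embedding)
                      for v in layer.trainable_variables]
    with tf.GradientTape() as tape:
        y_pred = model(x_batch, training=True)
        loss = loss_fn(y_batch, y_pred)
    grads = tape.gradient(loss, embedding_vars)
    # Embedding gradients arrive as IndexedSlices; densify before taking norms.
    # Norms near zero would support the "embedding converged first" hypothesis.
    return [float(tf.norm(tf.convert_to_tensor(g))) for g in grads if g is not None]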
A solution to try may look like this:
import tensorflow as tf
import numpy as np


class TemporalEmbedding(tf.keras.layers.Layer):
    """Uses a trainable embedding for the first n calls, then copies its weights
    into a frozen twin and serves that, so no further embedding gradients flow."""

    def __init__(self, vocab_size, embedding_dim, **kwargs):
        super().__init__(trainable=True)
        self.compute_gradient_for_n_epochs = 7
        self.train_counter = 0  # how many times call() has run so far
        self.embedding_1 = tf.keras.layers.Embedding(vocab_size, embedding_dim, **kwargs)
        self.embedding_2 = tf.keras.layers.Embedding(vocab_size, embedding_dim, **kwargs)
        self.embedding_2.trainable = False  # the frozen copy

    def set_compute_gradient_for_n_epochs(self, n: int):
        self.compute_gradient_for_n_epochs = n
        print(f"Training this layer for only {self.compute_gradient_for_n_epochs} epochs")

    def call(self, inputs):
        print(f"starting state: {self.train_counter}")
        if self.train_counter < self.compute_gradient_for_n_epochs:
            # Still inside the training window: use the trainable embedding.
            print(f"Training weights for epoch {self.train_counter}")
            self.train_counter += 1
            return self.embedding_1(inputs)
        elif self.train_counter == self.compute_gradient_for_n_epochs:
            # Window just ended: copy the learned weights into the frozen twin.
            print(f"Setting trained weights to untrainable model (1) {self.train_counter}")
            self.train_counter += 1
            weights_0 = self.embedding_1.get_weights()
            self.embedding_2.set_weights(weights_0)
            print("Returning weights from untrainable model")
            return self.embedding_2(inputs)
        else:
            print(f"Returning weights from untrainable model (2) {self.train_counter}")
            self.train_counter += 1
            return self.embedding_2(inputs)


input_layer = tf.keras.layers.Input(shape=(100,))
# input_length must match the Input shape of (100,).
temporal_embedding_layer = TemporalEmbedding(vocab_size=10000, embedding_dim=64, input_length=100)
temporal_embedding_layer.set_compute_gradient_for_n_epochs(n=3)
temporal_embedding_layer_called = temporal_embedding_layer(input_layer)
flat = tf.keras.layers.Flatten()(temporal_embedding_layer_called)
output_layer = tf.keras.layers.Dense(10, activation='softmax')(flat)
model2 = tf.keras.Model(inputs=input_layer, outputs=output_layer)

opt = tf.keras.optimizers.Adam(learning_rate=0.001)
model2.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])

# Toy data: 200 random sequences of length 100 with random 10-class targets.
x_train = np.random.randint(10000, size=(200, 100))
y_train = np.random.randint(2, size=(200, 10))
model2.fit(x_train, y_train, epochs=20, batch_size=32)
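For comparison, a plain Embedding layer can be frozen after a few epochs by splitting fit() into two phases. This is only an illustrative alternative to the layer above, not part of the proposal; model3 and embedding are made-up names, and the toy x_train / y_train from above are reused.

# Illustrative alternative: train a plain Embedding for 3 epochs, freeze it,
# then keep training only the dense layers.
inputs = tf.keras.layers.Input(shape=(100,))
embedding = tf.keras.layers.Embedding(10000, 64)
x = tf.keras.layers.Flatten()(embedding(inputs))
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
model3 = tf.keras.Model(inputs=inputs, outputs=outputs)

# Phase 1: embedding and dense layers both train.
model3.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model3.fit(x_train, y_train, epochs=3, batch_size=32)

# Phase 2: freeze the embedding; recompiling makes the new trainable state take effect.
embedding.trainable = False
model3.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model3.fit(x_train, y_train, epochs=20, initial_epoch=3, batch_size=32)

Recompiling with the string 'adam' creates a fresh optimizer, so pass the existing optimizer instance instead if its accumulated state should carry over into the second phase.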
Suggested Labels (If you don't know, that's ok): kind/enhancement