tandem-embeddings-with-freezable-weights #140

Open
david-thrower opened this issue Dec 16, 2023 · 0 comments

Kind of issue: The bottleneck in the tandem embeddings may be that the embedding converges to an optimum well before the dense layers do. Consequently, the embedding gradients will zero out. Due to the chain rule, this will cascade and zero out all of the other gradients as well.
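
One way to sanity-check that claim is to watch per-layer gradient norms. A minimal sketch (the toy model, names, and random data here are made up for illustration, not part of the proposal); if the embedding has converged, its gradient norm should sit near zero:

import tensorflow as tf
import numpy as np

# Toy model mirroring the shapes used later in this issue.
inp = tf.keras.layers.Input(shape=(100,))
emb = tf.keras.layers.Embedding(10000, 64)(inp)
out = tf.keras.layers.Dense(10, activation='softmax')(tf.keras.layers.Flatten()(emb))
model = tf.keras.Model(inp, out)

x = np.random.randint(10000, size=(32, 100))
y = tf.keras.utils.to_categorical(np.random.randint(10, size=(32,)), num_classes=10)

with tf.GradientTape() as tape:
    loss = tf.keras.losses.categorical_crossentropy(y, model(x))
grads = tape.gradient(loss, model.trainable_variables)
for var, g in zip(model.trainable_variables, grads):
    # Embedding gradients arrive as IndexedSlices; densify before taking the norm.
    print(var.name, float(tf.norm(tf.convert_to_tensor(g))))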

A solution to try may look like this: train one embedding for a set number of calls, then copy its weights into a permanently frozen twin and route all subsequent calls through the twin:

import tensorflow as tf
import numpy as np

class TemporalEmbedding(tf.keras.layers.Layer):
    """Routes through a trainable embedding for the first n calls, then
    through a permanently frozen copy of it."""

    def __init__(self, vocab_size, embedding_dim, **kwargs):
        super().__init__(trainable=True)
        self.compute_gradient_for_n_epochs = 7
        # Note: this counter advances once per call to the layer (i.e. per
        # batch when running eagerly), not once per epoch.
        self.train_counter = 0
        self.embedding_1 = tf.keras.layers.Embedding(vocab_size, embedding_dim, **kwargs)
        # Frozen twin that will receive embedding_1's trained weights.
        self.embedding_2 = tf.keras.layers.Embedding(vocab_size, embedding_dim, **kwargs)
        self.embedding_2.trainable = False

    def set_compute_gradient_for_n_epochs(self, n: int):
        self.compute_gradient_for_n_epochs = n
        print(f"Training this layer for only {self.compute_gradient_for_n_epochs} epochs")

    def call(self, inputs):
        print(f"starting state: {self.train_counter}")
        if self.train_counter < self.compute_gradient_for_n_epochs:
            # Trainable phase: route through embedding_1 so it receives gradients.
            print(f"Training weights for epoch {self.train_counter}")
            self.train_counter += 1
            return self.embedding_1(inputs)
        elif self.train_counter == self.compute_gradient_for_n_epochs:
            # Transition: copy the trained weights into the frozen twin.
            print(f"Setting trained weights to untrainable model (1) {self.train_counter}")
            self.train_counter += 1
            if not self.embedding_2.built:
                # set_weights() fails on an unbuilt layer, so build it first.
                self.embedding_2.build(inputs.shape)
            self.embedding_2.set_weights(self.embedding_1.get_weights())
            print("Returning weights from untrainable model")
            return self.embedding_2(inputs)
        else:
            # Frozen phase: no gradients flow into the embedding weights.
            print(f"Returning weights from untrainable model (2) {self.train_counter}")
            self.train_counter += 1
            return self.embedding_2(inputs)
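
A note on the two-embedding design: simply flipping embedding_1.trainable partway through training would not take effect on its own, because tf.keras only picks up changes to trainable when the model is recompiled. Copying the trained weights into an always-frozen twin inside call sidesteps the recompile entirely.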


input_layer = tf.keras.layers.Input(shape=(100,))
# shape=(100,) already fixes the sequence length, so input_length is not
# passed here (input_length=10 would conflict with the 100-token inputs).
temporal_embedding_layer = TemporalEmbedding(vocab_size=10000, embedding_dim=64)
temporal_embedding_layer.set_compute_gradient_for_n_epochs(n=3)
temporal_embedding_layer_called = temporal_embedding_layer(input_layer)
flat = tf.keras.layers.Flatten()(temporal_embedding_layer_called)
output_layer = tf.keras.layers.Dense(10, activation='softmax')(flat)
model2 = tf.keras.Model(inputs=input_layer, outputs=output_layer)
opt = tf.keras.optimizers.Adam(learning_rate=0.001)
# run_eagerly=True keeps the Python-side train_counter live; under a traced
# tf.function it would only advance during tracing, not once per batch.
model2.compile(optimizer=opt, loss='categorical_crossentropy',
               metrics=['accuracy'], run_eagerly=True)


# Synthetic demo data: random token ids and random one-hot labels.
x_train = np.random.randint(10000, size=(200, 100))
y_train = tf.keras.utils.to_categorical(np.random.randint(10, size=(200,)), num_classes=10)

model2.fit(x_train, y_train, epochs=20, batch_size=32)
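
Since the train_counter above counts calls rather than true epochs, an alternative worth trying is a standard Keras Callback that freezes a plain Embedding layer after n epochs and recompiles. A minimal sketch, assuming an ordinary tf.keras.layers.Embedding named embedding_layer in place of the tandem pair (the names here are hypothetical):

class FreezeAfterNEpochs(tf.keras.callbacks.Callback):
    """Freeze `layer` once `n_epochs` have completed, then recompile."""

    def __init__(self, layer, n_epochs):
        super().__init__()
        self.layer = layer
        self.n_epochs = n_epochs

    def on_epoch_end(self, epoch, logs=None):
        if epoch + 1 == self.n_epochs and self.layer.trainable:
            self.layer.trainable = False
            # Changes to `trainable` only take effect after recompiling.
            self.model.compile(optimizer=self.model.optimizer,
                               loss='categorical_crossentropy',
                               metrics=['accuracy'])

# Hypothetical usage:
# model2.fit(x_train, y_train, epochs=20, batch_size=32,
#            callbacks=[FreezeAfterNEpochs(embedding_layer, n_epochs=3)])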

Suggested Labels (If you don't know, that's ok): kind/enhancement
