This repository was archived by the owner on Jul 10, 2025. It is now read-only.

RFC: Sparse Domain Isolation for Supporting large-scale Sparse Weights Training. #237

Open · wants to merge 2 commits into base: master

Conversation

rhdong
Member

@rhdong rhdong commented Apr 25, 2020

Sparse Domain Isolation for supporting large-scale Recommender Systems.

Status Draft
Author(s) Haidong Rong (hudsonrong@tencent.com), Yafei Zhang (kimmyzhang@tencent.com), Jiandong Wang (adnywang@tencent.com), Chuan Cheng (chuancheng@tencent.com)
Reviewer(s) Alexandre Passos (alexandre.tp@gmail.com), Bairen Yi (yibairen.byron@bytedance.com)
Sponsor Yuefeng Zhou (yuefengz@google.com) Zhenyu Tan (tanzheny@google.com)
Updated 2020-09-16

@yuefengz @byronyi
Hi,
This is the RFC of Sparse Domain Isolation for supporting large-scale Recommender Systems.
It's still a draft; we will post the latest content as soon as possible and keep improving it on this basis. To push this forward quickly, I submitted it here first, but the owners are everyone who participated in the past discussions, and we will complete the list later.

@rhdong rhdong changed the title first commit for sparse domain isolation. RFC: Sparse Domain Isolation for Supporting large-scale Sparse Weights Training. Apr 29, 2020
@yuefengz
Contributor

@byronyi If we are going to contribute to addons first, do we need an RFC here?

@smilingday
Contributor

Since this RFC targets SIG Addons, adding SIG Addons leads
@facaiy @seanpmorgan
and TF sponsor @karmel
as reviewers.

Contributor

@alextp alextp left a comment

That's a very interesting proposal.

From a high level view (and I'm probably wrong) it looks like it proposes a new type of variable and a new type of optimizer which can update that variable. Given that this is the case I think we can implement this in addons or some other SIG package as long as there are APIs in core TF to ensure that this variable can declare itself checkpointable, be tracked by something like tf.Module / keras.Model (so you can do model.trainable_sparse_variables), and maybe be automatically watched via the gradient tape.

Can you expand the document to clarify the details of these changes to existing parts of TF as opposed to most of the content which is on the new types?

Thanks!
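As a rough illustration of the core-API point above, here is a minimal sketch (not part of the RFC, assuming TF 2.x) of a hash-table-backed weight store that is already trackable by tf.Module and saveable with tf.train.Checkpoint:

```python
import tensorflow as tf

class SparseWeights(tf.Module):
    """Illustrative only: a hash-table-backed weight store that tf.Module
    tracks and tf.train.Checkpoint can save, without new core APIs."""

    def __init__(self, dim, name=None):
        super().__init__(name=name)
        # MutableHashTable is a trackable resource, so the checkpointing
        # machinery picks it up automatically.
        self.table = tf.lookup.experimental.MutableHashTable(
            key_dtype=tf.int64,
            value_dtype=tf.float32,
            default_value=tf.zeros([dim]))

    def lookup(self, ids):
        return self.table.lookup(ids)

weights = SparseWeights(dim=8)
ckpt = tf.train.Checkpoint(weights=weights)
# ckpt.write("/tmp/sparse_ckpt")  # would save the table contents as well
```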

@byronyi
Contributor

byronyi commented Apr 30, 2020

@byronyi If we are going to contribute to addons first, do we need an RFC here?

I guess the design was originally targeted to TF core.

As @alextp said, if part of it still requires changes to TF core, then we still need a (probably smaller) RFC here.

@alextp
Contributor

alextp commented Apr 30, 2020 via email

@rhdong
Member Author

rhdong commented May 4, 2020

That's a very interesting proposal.

From a high level view (and I'm probably wrong) it looks like it proposes a new type of variable and a new type of optimizer which can update that variable. Given that this is the case I think we can implement this in addons or some other SIG package as long as there are APIs in core TF to ensure that this variable can declare itself checkpointable, be tracked by something like tf.Module / keras.Model (so you can do model.trainable_sparse_variables), and maybe be automatically watched via the gradient tape.

Can you expand the document to clarify the details of these changes to existing parts of TF as opposed to most of the content which is on the new types?

Thanks!

Thank you.
In fact, my initial idea was to encapsulate some kind of ResourceVariable-backed hash table, since, as we know, TF is not good at training anything that is not a tf.Variable. I reuse lookup.MutableHashTable because I don't want to write a new hash library in TF; in particular, the lookup.* ops already support checkpointing and can be deployed on tf.distribute.Server.
Here is a comparison against v1.15.2 that shows the scope of core code affected by the RFC:
https://github.com/tensorflow/tensorflow/compare/v1.15.2...rhdong:rfc?expand=1

The main changes:

  1. Support for a random initializer on lookup.MutableHashTable.Find
  2. Adaptation of four stateful optimizers (Adagrad, Adam, FTRL, Momentum). (May be cancelled in the new scheme.)

Thanks!
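For reference, a rough sketch of change 1, emulated in user code on top of the existing MutableHashTable API rather than inside the Find kernel (the all-zeros sentinel is only illustrative):

```python
import tensorflow as tf

dim = 8
table = tf.lookup.experimental.MutableHashTable(
    key_dtype=tf.int64, value_dtype=tf.float32,
    default_value=tf.zeros([dim]))

def find_with_random_init(ids):
    """Emulates 'random initializer on Find': keys never seen before get a
    freshly sampled embedding inserted on first lookup. A real kernel-level
    implementation would detect missing keys directly instead of treating
    the all-zeros default as a sentinel."""
    values = table.lookup(ids)
    missing = tf.reduce_all(tf.equal(values, 0.0), axis=-1, keepdims=True)
    random_init = tf.random.normal(tf.shape(values), stddev=0.1)
    values = tf.where(missing, random_init, values)
    table.insert(ids, values)
    return values

# e.g. find_with_random_init(tf.constant([1, 2, 3], dtype=tf.int64))
```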

@alextp
Contributor

alextp commented May 4, 2020 via email

@rhdong
Member Author

rhdong commented May 4, 2020

The change to the existing SparseApply* kernels which removes Ref(T) from the signature is backwards incompatible and can't be done. Adding new kernels for the hash apply is fine, though. I do wonder if we need the Optimizer method _apply_dense_hash or whether we can use a separate optimizer-like class which knows about the hash application. This has the advantage that it naturally covers the use cases where people want different optimizers for the really sparse embedding layers (which I think is relatively common).


Yes, you're right; this is only a temporary version. I have changed the name to _apply_dense_unstateful; XX_hash was a bad name.
Regarding a separate optimizer class, I'm not sure which option would be better. I prefer using the same optimizer to give algorithm engineers a consistent experience, because a deep-learning RecSys model may contain dense weights and sparse weights at the same time.

@yuefengz
Contributor

yuefengz commented May 4, 2020

I think TensorFlow can provide a way to extend optimizers so that you can extend existing optimizers to handle your sparse weights.
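One possible shape of such an extension point, sketched against the Keras OptimizerV2 API; `is_hash_backed` and the hash-update helper are assumptions, not existing TF APIs:

```python
import tensorflow as tf

class HashAdagrad(tf.keras.optimizers.Adagrad):
    """Sketch of 'extending an existing optimizer' for hash-table-backed
    weights: intercept the sparse path for such variables and fall back to
    the stock behaviour for ordinary variables."""

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        if getattr(var, "is_hash_backed", False):  # assumed marker attribute
            return self._hash_apply_sparse(grad, var, indices)
        return super()._resource_apply_sparse(
            grad, var, indices, apply_state=apply_state)

    def _hash_apply_sparse(self, grad, var, indices):
        # Placeholder: a real implementation would update the hash-table
        # entries addressed by `indices` (and their optimizer slots).
        raise NotImplementedError("hash-table update path not implemented")
```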

@alextp
Contributor

alextp commented May 4, 2020 via email

@byronyi
Contributor

byronyi commented May 4, 2020

I think TensorFlow can provide a way to extend optimizers so that you can extend existing optimizers to handle your sparse weights.

cc @omalleyt12, who proposed the new customizable optimizer in #234. Mind shedding some light on this?

@rhdong
Member Author

rhdong commented May 23, 2020

@yuefengz @byronyi @alextp @smilingday @facaiy @seanpmorgan @omalleyt12
Hi all,
I just committed an important update to the optimizer-reuse scheme based on ResourceVariable and added a detailed API design. I will provide a runnable demo on docker.io as soon as possible.
Thank you.

@rhdong
Member Author

rhdong commented May 23, 2020

I think this version of the scheme is simple and natural enough for core.

Contributor

@byronyi byronyi left a comment

Thanks for the updates. Several issues that may be worth considering:

  1. These APIs might not warrant a separate TF sub-package, i.e. tf.dynamic_embedding. How about, e.g., tf.nn.dynamic_embedding_lookup?
  2. Try to override methods in subclasses and avoid modifying the base class when not necessary.
  3. trainable_wrap is a resource variable, but dynamic_embedding.Variable is not (any specific reason?). I confirmed with the author offline that dynamic_embedding.Variable represents the whole embedding layer, while TrainableWrap wraps a single-value lookup from the embedding (see the sketch below). The current naming doesn't reflect these semantics.
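A plain-Python sketch of the relationship in point 3 (all names illustrative, not the RFC's actual classes):

```python
class DynamicEmbeddingLayer:
    """Stands in for dynamic_embedding.Variable: the whole embedding layer,
    backed by a hash table keyed by feature id."""
    def __init__(self, dim):
        self.dim = dim
        self.table = {}  # id -> embedding vector

class BatchLookupWrapper:
    """Stands in for TrainableWrap: holds only the rows looked up for the
    current batch, which is what the optimizer actually updates."""
    def __init__(self, layer, ids, values):
        self.layer = layer    # reference back to the full embedding layer
        self.ids = ids        # keys for this batch only
        self.values = values  # dense tensor of looked-up rows
```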

name='embedding_lookup',
max_norm=None):
"""Provides a dynamic version of embedding_lookup
similar to tf.nn.embedding_lookup.
Contributor

I am not sure if this deserves another top-level package name. How about tf.nn.dynamic_embedding_lookup?

Member Author

I am not sure if this deserves another top-level package name. How about tf.nn.dynamic_embedding_lookup?

Good idea!

###

@tf_export(v1=["dynamic_embedding.embedding_lookup_sparse"])
def embedding_lookup_sparse(params,
Contributor

Same here, how about tf.nn.dynamic_embedding_lookup_sparse?

Member Author

@rhdong rhdong May 24, 2020

...
Yes, you are right; maybe tf.nn or tf.keras would be a better choice. Hi @yuefengz, could you give us some advice?

Contributor

I'd prefer to put this in a separate repo first. We can allow it to graduate into TF core if it has a large number of users.

Member Author

@rhdong rhdong May 28, 2020

I'd prefer to put this in a separate repo first. We can allow it to graduate into TF core if it has a large number of users.

Maybe we can merge it with tf.nn.embedding_lookup, since they have the same input arguments.


##### Runtime random initialization

Since sparse training does not allocate memory at the start of the training loop, sparse weights cannot be initialized statically as we do with `tf.Variable`.
Contributor

AFAIK the value of tf.Variable is not initialized statically. Perhaps you mean the shape of the variable has to be static?

Member Author

I will check the code.
What I mean is that all values in a tf.Variable are determined before training starts, while the values in the hash table are unknown because their memory has not yet been allocated.

* tensorflow/core/kernels/lookup_table_op.cc

```cpp
Status MutableHashTableOfTensors::Find(
Contributor

This is an implementation detail, but consider a function name like FindOrInsert to avoid confusion.

Contributor

@yuefengz yuefengz left a comment

Thank you so much for your RFC, Haidong! Some high-level comments:

  1. I think this doc should focus on what changes need to be made in TF core, and probably mention a bit about how you are going to leverage those changes.
  2. Your proposal should be TF2-compatible. Some proposed changes, such as those to the optimizers, seem applicable only to TF1.
  3. I still have some confusion about the overall workflow. I recommend you implement a demo with minimal viable features and some simple unit tests in your own repo first. You can monkey-patch TF core to make the necessary changes (see the sketch below).
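A minimal sketch of the monkey-patch approach from point 3, assuming TF 2.x; the patched body is illustrative only:

```python
import tensorflow as tf
from tensorflow.python.ops import embedding_ops

_original_embedding_lookup = embedding_ops.embedding_lookup

def patched_embedding_lookup(params, ids, *args, **kwargs):
    # A prototype could route hash-table-backed `params` to a custom lookup
    # here; everything else falls through to the original implementation.
    return _original_embedding_lookup(params, ids, *args, **kwargs)

embedding_ops.embedding_lookup = patched_embedding_lookup
```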

initial_value=vals,
dtype=params.value_dtype,
trainable=params.trainable)
if max_norm is not None:
Contributor

Why is TrainableWrap not used in this branch?

Member Author

Sorry, it is a bug; I will fix it.

###

@tf_export(v1=["dynamic_embedding.embedding_lookup_sparse"])
def embedding_lookup_sparse(params,
Contributor

I'd prefer to put this in a separate repo first. We can allow it to graduate into TF core if it has a large number of users.

slot_name,
op_name):
"""Helper function for creating a slot variable for statefull optimizers."""
_params_var, _params_ids = _IDS_OF_RESOURCEVARIABLE_MAPPER_.get(primary)
Contributor

What is _IDS_OF_RESOURCEVARIABLE_MAPPER_?

It looks like your slot variable is created for every step, since it depends on the ids (while there is no such requirement in the original optimizers). Is that correct?

Member Author

What is _IDS_OF_RESOURCEVARIABLE_MAPPER_?

It looks like your slot variable is created for every step, since it depends on the ids (while there is no such requirement in the original optimizers). Is that correct?

I will remove _IDS_OF_RESOURCEVARIABLE_MAPPER_ by reusing Optimizer.slots.
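For reference, a small sketch of reusing the optimizer's own slot bookkeeping (Keras OptimizerV2 API): slots are keyed by the variable and created once, not rebuilt from the batch ids each step:

```python
import tensorflow as tf

opt = tf.keras.optimizers.Adagrad(learning_rate=0.1)
v = tf.Variable(tf.zeros([4, 8]))

# Created once, keyed by the variable itself rather than by the batch ids.
acc = opt.add_slot(v, "accumulator")
assert opt.get_slot(v, "accumulator") is acc
```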


```python

class _DenseDynamicEmbeddingTrainableProcessor(_OptimizableVariable):
Contributor

Does this class need to be different for different optimizers?

Member Author

No need to be.


```python
@tf_export("dynamic_embedding.Variable")
class Variable(object):
Contributor

We probably can use a different name instead of Variable?

Member Author

We probably can use a different name instead of Variable?

Yeah, agree with you.

"""
pass

def remove(self, keys, name=None):
Contributor

Could you add some details for when the remove would be triggered?

Member Author

For example, users may want to control the total number of features with some custom policy.

Contributor

Adding some docs would help.


## Sparse Domain Isolation

### Overview of Design
Contributor

What does the gradient look like? Do you need to define custom gradient function for embedding_lookup?

Member Author

What does the gradient look like? Do you need to define custom gradient function for embedding_lookup?

No, there is no need to define a custom gradient function.
I believe back-propagation should end after the update to TrainableWrap.
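A tiny sketch of that point, with tf.Variable standing in for TrainableWrap: the tape yields an ordinary gradient for the wrapper, and the hash table only sees the write-back afterwards:

```python
import tensorflow as tf

# Stands in for TrainableWrap holding the values looked up for one batch.
wrapped_values = tf.Variable(tf.random.normal([2, 8]))

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.square(wrapped_values))

grad = tape.gradient(loss, wrapped_values)   # ordinary variable gradient
wrapped_values.assign_sub(0.1 * grad)        # optimizer-style in-place update
# ...after which the updated rows would be inserted back into the hash table;
# back-propagation itself never differentiates through the table.
```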


```python

class TrainableWrap(ResourceVariable):
Contributor

My understanding is that TrainableWrap is a container holding (dynamic_embedding, ids) that pretends to be a tf.Variable. I am wondering whether it has to be a ResourceVariable?

Member Author

My understanding is that TrainableWrap is a container holding (dynamic_embedding, ids) that pretends to be a tf.Variable. I am wondering whether it has to be a ResourceVariable?

It does, because all optimizers can train a ResourceVariable through _resource_apply_dense/_resource_apply_sparse; otherwise you would have to extend each optimizer one by one (a minimal sketch follows below).
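A minimal sketch of that argument, assuming TF 2.x eager execution; the `TrainableWrap` below is a bare stand-in, not the RFC's implementation:

```python
import tensorflow as tf
from tensorflow.python.ops.resource_variable_ops import ResourceVariable

class TrainableWrap(ResourceVariable):
    """Bare stand-in: anything that is a ResourceVariable goes through the
    stock _resource_apply_dense/_resource_apply_sparse paths, so built-in
    optimizers can update it without per-optimizer changes."""
    pass

w = TrainableWrap(initial_value=tf.zeros([4, 8]), trainable=True)
opt = tf.keras.optimizers.SGD(learning_rate=0.1)

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.square(w - 1.0))

opt.apply_gradients([(tape.gradient(loss, w), w)])
```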

@facaiy
Member

facaiy commented Jun 3, 2020

Since this RFC targets SIG Addons, adding SIG Addons leads
@facaiy @seanpmorgan
and TF sponsor @karmel
as reviewers.

Thanks for pinging me, @yuefengz @smilingday. The proposal is very interesting. I'm wondering whether we can introduce a new kind of Variable class and reuse all existing optimizers (in tf-core or tf-addons).

I'm afraid the proposal goes beyond the scope of tf-addons, so I suggest putting it in a separate repo first. @seanpmorgan Sean, what do you think?

@smilingday
Contributor

Sean has discussed this in SIG Addons meetings and replied in separate email threads that tf-addons might not be a good fit. We are still exploring the right place for these contributions.

@rhdong
Member Author

rhdong commented Jun 4, 2020 via email

@levyfan

levyfan commented Jun 17, 2020

Is this RFC related to the recently proposed paper "DynamicEmbedding: Extending TensorFlow for Colossal-Scale Applications" by Google? https://arxiv.org/pdf/2004.08366.pdf

@rhdong
Member Author

rhdong commented Jun 24, 2020

Is this RFC related to the recently proposed paper "DynamicEmbedding: Extending TensorFlow for Colossal-Scale Applications" by Google? https://arxiv.org/pdf/2004.08366.pdf

No, this is a different scheme, proposed in an earlier paper, Distributed Equivalent Substitution Training for Large-Scale Recommender Systems (accepted at SIGIR 2020).

rhdong added a commit to rhdong/tensorflow that referenced this pull request Jul 14, 2020
@rhdong
Member Author

rhdong commented Jul 14, 2020

@yuefengz @tanzhenyu @byronyi @alextp Hi, I just updated this RFC. The update contains some key features, including a scheme that is compatible with all tf.initializers without hacking too much on MutableHashTableOfTensors::Find, and we also provided our patch to core: tensorflow/tensorflow#41371. Please help us improve it, thank you!

@shenbaise

is it compatible with tensorflow serving ? @rhdong

@rhdong
Member Author

rhdong commented Jul 24, 2020

is it compatible with tensorflow serving ? @rhdong

Yes

@shenbaise

shenbaise commented Jul 28, 2020

Hi @rhdong, I fixed some bugs (the shape of TrainableWrapper) and built TF 2.4.0 based on your code. It seems the dynamic_embedding is not updated during training.

Code as follows:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Lambda
from tensorflow import dynamic_embedding as de
import numpy as np

idx = np.random.randint(0, 10, 100)
label = np.array([1.0 if a % 2 == 0 else 0.0 for a in idx], dtype=np.float32)

class MyModel(tf.keras.Model):
  def __init__(self):
    super(MyModel, self).__init__()
    self.w = de.get_variable(name="dynamic_embeddings", dim=8, initializer=np.random.random(8))
    self.d0 = Lambda(lambda x: de.embedding_lookup(params=self.w, ids=x, name="wide-sparse-weights"))
    self.d1 = Dense(10, activation='relu')
    self.d2 = Dense(1, activation='sigmoid')
    self.x0 = None
  def call(self, x):
    self.x0 = self.d0(x)
    x1 = self.d1(self.x0)
    return self.d2(x1)

model = MyModel()
loss_func = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.Adagrad(learning_rate=.5)
train_loss = tf.keras.metrics.Mean(name='train_loss')

def train_step(x, label, print_loss=False):
    with tf.GradientTape() as tape:
      logits = model(x)
      loss = loss_func(logits, label)
    trainable_weights = model.trainable_variables
    # trainable_weights.append(model.x0)
    grads = tape.gradient(loss, trainable_weights)
    optimizer.apply_gradients(zip(grads, trainable_weights))
    if print_loss:
        print("loss:{}".format(train_loss(loss).numpy()))

def emb_sum():
    a = de.embedding_lookup(params=model.w, ids=np.array([2, 3]), name="wide-sparse-weights")
    return a.numpy().sum()

def kernel_sum():
    return model.d1.kernel.numpy().sum()

print("emb sum:{}".format(emb_sum()))
for i in range(20):
    train_step(idx.reshape(100, 1), label.reshape(100, 1))
print("emb sum:{}".format(emb_sum()))
print("kernel sum:{}".format(kernel_sum()))
# train more
for i in range(10):
    train_step(idx.reshape(100, 1), label.reshape(100, 1), print_loss=True)
print("emb sum:{}".format(emb_sum()))
print("kernel sum:{}".format(kernel_sum()))

# print trainable weights
print([v.name for v in model.trainable_weights])

console:

emb sum:**-0.031497083604335785**
emb sum:**-0.031497083604335785**
kernel sum:0.6821714043617249
loss:7.522636
loss:7.52227
loss:7.5219383
loss:7.521633
loss:7.521351
loss:7.521089
loss:7.520846
loss:7.5206184
loss:7.5204053
loss:7.5202055
emb sum:**-0.031497083604335785**
kernel sum:0.6808109283447266
['my_model/dense/kernel:0', 'my_model/dense/bias:0', 'my_model/dense_1/kernel:0', 
'my_model/dense_1/bias:0', 'my_model/lambda/TrainableWrapper:0']

@rhdong
Member Author

rhdong commented Jul 28, 2020

> (quoting @shenbaise's code and console output from the comment above)

@shenbaise Thank you for the feedback; I will check and fix it as soon as possible.

@rhdong
Member Author

rhdong commented Jul 30, 2020

> (quoting @shenbaise's code, console output, and my earlier reply from above)

Hi @shenbaise, the reason is that the commit is not compatible with Keras, especially optimizer v2. I need two days to fix it and add the UT cases; please wait a moment. Thank you!

@rhdong
Member Author

rhdong commented Aug 11, 2020

> (quoting @shenbaise's code and console output from the comment above)

Hi @shenbaise, I fixed the issue; the commit is here.

@mdanatg

mdanatg commented Aug 26, 2020

FYI, @kttian wrote a prototype for a differentiable hash map, roughly the equivalent of TensorList, as part of her internship project. Here's a colab that demonstrates direct gradient updates: https://colab.sandbox.google.com/drive/1hyFmriuq4Bz61_rxg2bfdE_jXHVfX8Rr?usp=sharing#scrollTo=8HDUUBEFAesC

There may be an opportunity to join efforts on a core implementation.

@alextp @saxenasaurabh @dynamicwebpaige

@rhdong
Member Author

rhdong commented Sep 6, 2020

FYI, @kttian wrote a prototype for a differentiable hash map, roughly the equivalent of TensorList, as part of her internship project. Here's a colab that demonstrates direct gradient updates: https://colab.sandbox.google.com/drive/1hyFmriuq4Bz61_rxg2bfdE_jXHVfX8Rr?usp=sharing#scrollTo=8HDUUBEFAesC

There may be an opportunity to join efforts on a core implementation.

@alextp @saxenasaurabh @dynamicwebpaige

This is good work, but I think it is difficult to make the hash map trainable.

@mdanatg

mdanatg commented Sep 8, 2020

FYI, @kttian wrote a prototype for a differentiable hash map, roughly the equivalent of TensorList, as part of her internship project. Here's a colab that demonstrates direct gradient updates: https://colab.sandbox.google.com/drive/1hyFmriuq4Bz61_rxg2bfdE_jXHVfX8Rr?usp=sharing#scrollTo=8HDUUBEFAesC
There may be an opportunity to join efforts on a core implementation.
@alextp @saxenasaurabh @dynamicwebpaige

This is good job. But I think it is difficult to make the hash map trainable .

It already is trainable (at least in the sense of trainable that I believe you're referring to).

rhdong added a commit to rhdong/tensorflow that referenced this pull request Sep 16, 2020
rhdong added a commit to rhdong/tensorflow that referenced this pull request Sep 16, 2020
rhdong added a commit to rhdong/tensorflow that referenced this pull request Oct 16, 2020
rhdong added a commit to rhdong/tensorflow that referenced this pull request Oct 22, 2020
rhdong added a commit to rhdong/tensorflow that referenced this pull request Oct 26, 2020
rhdong added a commit to rhdong/tensorflow that referenced this pull request Nov 19, 2020
rhdong added a commit to rhdong/tensorflow that referenced this pull request Dec 21, 2020
…up support full size dynamic default values.

This PR is one part of RFC:tensorflow/community#237
@ematejska

@yuefengz Is this still in draft mode? What are the plans with this RFC?

xinan-jiang pushed a commit to xinan-jiang/tensorflow that referenced this pull request Jan 4, 2024

### Trainable Wrapper

In the early scheme, the `dynamic_embedding.Variable` was trained directly by the optimizers and **we had to extend all optimizers one-by-one**.


Are there any examples of how to train with dynamic_embedding.Variable and how to update optimizers to accommodate that? I would love to learn more about this solution, especially how the optimizer should be updated to support the training.

cc: @rhdong @yuefengz @byronyi @alextp

Thanks.

Member Author

Hi @yeyinqcri, thanks for reaching out! You might be interested in this long-maintained repo, which was inspired by this RFC: https://github.com/tensorflow/recommenders-addons


Thanks for the prompt response. Yes, I am aware of TFRA and understand that TFRA uses TrainableWrapper to make the hash table trainable with the default TensorFlow runtime. However, we ran into an issue when using TFRA + PS to train a model in Keras. This is understandable, given that PS is already on the deprecation path. What I am trying to test is whether it is still possible to use the old way (extending the optimizer) to make it work.

Member Author

Thanks for the prompt response. Yes, I am aware of TFRA and understand that TFRA uses TrainableWrapper to make the hash table trainable with the default TensorFlow runtime. However, we ran into an issue when using TFRA + PS to train a model in Keras. This is understandable, given that PS is already on the deprecation path. What I am trying to test is whether it is still possible to use the old way (extending the optimizer) to make it work.

Oh, I see. Did you submit an issue for it?


We will submit an issue for it. In the meantime, we want to explore other solutions (e.g., implementing an optimizer) to unblock our urgent needs if possible, so we would appreciate it if you could share how the old way works so we can follow it.

Member Author

Sure! More details in an issue will help us assess it better.
