
Commit e922d0a

Fix consistency with nn.embedding_lookup
shkarupa-alex committed Aug 8, 2024
1 parent 1830558 commit e922d0a
Showing 1 changed file with 26 additions and 15 deletions.
41 changes: 26 additions & 15 deletions tfmiss/nn/embedding.py
@@ -97,10 +97,12 @@ def adaptive_embedding_lookup(
             if param_p_dim is not None:
                 dim_0_sizes.append(param_p_dim)
             else:
-                with ops.colocate_with(params[p]):
+                with embedding_ops._colocate_with(params[p]):
                     dim_0_sizes.append(tf.shape(params[p])[0])
         dim_0_sizes = tf.stack(dim_0_sizes)
-        total_ids_capacity = tf.reduce_sum(dim_0_sizes)
+        total_ids_capacity = tf.cast(
+            tf.reduce_sum(dim_0_sizes), dtype=flat_ids.dtype
+        )
 
         p_cumsum = tf.cumsum(tf.cast(dim_0_sizes, dtype=flat_ids.dtype))
         assert_max_id = tf.debugging.assert_less(
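
The added cast matters because `tf.shape` yields int32 sizes while the lookup ids may be int64, and TensorFlow does not promote integer dtypes automatically when comparing tensors. A minimal sketch of the situation the new code handles (the table sizes and the assert call below are illustrative, not taken from the file):

import tensorflow as tf

# Illustrative sizes: two embedding tables with 3 and 5 rows.
dim_0_sizes = tf.stack([tf.constant(3), tf.constant(5)])  # int32
flat_ids = tf.constant([0, 4, 7], dtype=tf.int64)         # int64 ids

# Casting the summed capacity to the ids' dtype keeps the comparison legal;
# comparing int64 ids against an int32 capacity would raise a dtype error.
total_ids_capacity = tf.cast(tf.reduce_sum(dim_0_sizes), dtype=flat_ids.dtype)
assert_max_id = tf.debugging.assert_less(flat_ids, total_ids_capacity)
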
@@ -116,9 +118,8 @@ def adaptive_embedding_lookup(
             side="right",
         )
 
-        # Cast partition assignments to int32 for use in dynamic_partition.
-        # There really should not be more than 2^32 partitions.
-        p_assignments = tf.cast(p_assignments, tf.int32)
+        # No need to cast partition assignments to int32, `tf.searchsorted`
+        # already did it. There really should not be more than 2^32 partitions.
 
         # Partition list of ids based on assignments into np separate lists
         p_intervals = tf.concat(([0], p_cumsum), 0)
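
The dropped cast relies on `tf.searchsorted` returning int32 indices by default (its `out_type` argument defaults to `tf.int32`), which is the dtype `tf.dynamic_partition` expects. A small sketch with made-up sizes, assuming two tables of 3 and 5 rows:

import tensorflow as tf

# Cumulative row counts of the partitions: tables with 3 and 5 rows.
p_cumsum = tf.cumsum(tf.constant([3, 5], dtype=tf.int64))  # [3, 8]
flat_ids = tf.constant([0, 2, 3, 7], dtype=tf.int64)

# side="right" sends ids 0..2 to partition 0 and ids 3..7 to partition 1.
# The result is already int32 (the default out_type), so no extra cast is
# needed before tf.dynamic_partition.
p_assignments = tf.searchsorted(p_cumsum, flat_ids, side="right")
print(p_assignments.dtype, p_assignments.numpy())  # <dtype: 'int32'> [0 0 1 1]
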
@@ -133,18 +134,28 @@ def adaptive_embedding_lookup(
         for p in range(np):
             pids = gather_ids[p]
             transform_fn = transforms[p]
-
-            result = tf.gather(params[p], pids)
-            if 0 == p:
-                result = transform_fn(result)
-            else:
-                with ops.colocate_with(partitioned_result[0]):
-                    result = tf.identity(result)
-                    result = transform_fn(result)
-            result = embedding_ops._clip(result, pids, max_norm)
-
+            with tf.device(None):
+                with embedding_ops._colocate_with(params[0]):
+                    result = tf.gather(params[p], pids)
+                    result = embedding_ops._clip(result, pids, max_norm)
+            result = transform_fn(result)
             partitioned_result.append(result)
 
+        # TODO: check `AdaptiveEmbedding` CPU-GPU placements
+        # Before using: drop transform_fn in prev block
+        # transformed_result = []
+        # for p in range(np):
+        #     transform_fn = transforms[p]
+        #     if 0 == p:
+        #         result = transform_fn(partitioned_result[0])
+        #     else:
+        #         with tf.device(None):
+        #             with embedding_ops._colocate_with(partitioned_result[0]):
+        #                 result = tf.identity(partitioned_result[p])
+        #         result = transform_fn(result)
+        #     transformed_result.append(result)
+        # partitioned_result = transformed_result
+
         # Stitch these back together
         ret = data_flow_ops.parallel_dynamic_stitch(
             p_indices, partitioned_result, name=name
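
For context, a hedged usage sketch of `adaptive_embedding_lookup` as the touched code implies it: per-partition tables may have different widths, and each `transforms[p]` must map its partition to a common output width before the results are stitched back together. The import path, argument order, and shapes below are assumptions for illustration, not taken from the commit:

import tensorflow as tf
from tfmiss.nn import adaptive_embedding_lookup  # assumed import path

# Assumed adaptive setup: a wide table for frequent ids and a narrow table
# for rare ids, with the narrow one projected up to the common width.
head = tf.random.normal([1000, 64])   # ids 0..999, 64-dim rows
tail = tf.random.normal([9000, 16])   # ids 1000..9999, 16-dim rows
proj = tf.random.normal([16, 64])

transforms = [
    lambda e: e,                    # head rows already have the target width
    lambda e: tf.matmul(e, proj),   # project tail rows from 16 to 64
]

ids = tf.constant([[1, 5000], [42, 9999]], dtype=tf.int64)
embedded = adaptive_embedding_lookup([head, tail], ids, transforms)
print(embedded.shape)  # expected: (2, 2, 64)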