Skip to content

Commit 2c03dff

Browse files
siju-samuelTrevor Morris
authored andcommitted
[KERAS]Embedding layer (apache#5444)
1 parent d631dce commit 2c03dff

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

python/tvm/relay/frontend/keras.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,14 @@ def _convert_permute(inexpr, keras_layer, _):
207207
return _op.transpose(inexpr, axes=(0,) + keras_layer.dims)
208208

209209

210+
def _convert_embedding(inexpr, keras_layer, etab):
211+
indices = inexpr
212+
weightList = keras_layer.get_weights()
213+
weight = etab.new_const(weightList[0])
214+
out = _op.take(weight, indices.astype('int32'), axis=0)
215+
216+
return out
217+
210218
def _convert_dense(inexpr, keras_layer, etab):
211219
weightList = keras_layer.get_weights()
212220
weight = etab.new_const(weightList[0].transpose([1, 0]))
@@ -893,7 +901,7 @@ def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument
893901
'Maximum' : _convert_merge,
894902
'Dot' : _convert_merge,
895903
'Permute' : _convert_permute,
896-
# 'Embedding' : _convert_embedding,
904+
'Embedding' : _convert_embedding,
897905
# 'RepeatVector' : _convert_repeat_vector,
898906

899907
'InputLayer' : _default_skip,

tests/python/frontend/keras/test_forward.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,24 @@ def test_forward_zero_padding3d(self, keras):
466466
keras_model = keras.models.Model(data, x)
467467
verify_keras_frontend(keras_model, layout='NDHWC')
468468

469+
470+
def test_forward_embedding(self, keras):
471+
data = keras.layers.Input(shape=(2, 4), dtype="int32")
472+
x = keras.layers.Embedding(10, 3)(data)
473+
keras_model = keras.models.Model(data, x)
474+
verify_keras_frontend(keras_model, need_transpose=False)
475+
476+
data = keras.layers.Input(shape=(2, 3, 4), dtype="int32")
477+
x = keras.layers.Embedding(4, 5)(data)
478+
keras_model = keras.models.Model(data, x)
479+
verify_keras_frontend(keras_model, need_transpose=False)
480+
481+
data = keras.layers.Input(shape=(6, 2, 3, 4), dtype="int32")
482+
x = keras.layers.Embedding(4, 5)(data)
483+
keras_model = keras.models.Model(data, x)
484+
verify_keras_frontend(keras_model, need_transpose=False)
485+
486+
469487
if __name__ == '__main__':
470488
for k in [keras, tf_keras]:
471489
sut = TestKeras()
@@ -497,4 +515,4 @@ def test_forward_zero_padding3d(self, keras):
497515
sut.test_forward_pool3d(keras=k)
498516
sut.test_forward_upsample3d(keras=k)
499517
sut.test_forward_zero_padding3d(keras=k)
500-
518+
sut.test_forward_embedding(keras=k)

0 commit comments

Comments
 (0)