EmbeddingVariable是DeepRec中稀疏参数的表达,用户训练好模型,需要从ckpt 获取稀疏参数的信息,包含:稀疏特征,特征频次,以及特征版本。
EmbeddingVariable在存储为ckpt中,以tensor形式存储,假设图中EV的名称为a/b/c/embedding,则在ckpt 中的tensor为四个,即
- "a/b/c/embedding-keys" Tensor 存储EV的key,shape=[N],dtype=int64;
- "a/b/c/embedding-values" Tensor 存储EV的value,shape=[N, embedding_dim], dtype=float;
- "a/b/c/embedding-freqs" Tenosr 存储EV的频次,shape=[N],dtype=int64;
- "a/b/c/embedding-versions" Tensor 存储EV的最后一次更新的step数,shape=[N],dtype=int64
可以通过tensorflow的sdk,读取ckpt 的值从而获取到EV对应的信息 以上一组4个tensor的第一维shape相同,并且按序对应。设置了partition的EV,part 的tensor数目取决于设置的partitioner,该特征的稀疏参数为所有part的全部tensor集合。
import tensorflow as tf
a = tf.get_embedding_variable('a', embedding_dim=4)
b = tf.get_embedding_variable('b', embedding_dim=8, partitioner=tf.fixed_size_partitioner(4))
emb_a = tf.nn.embedding_lookup(a, tf.constant([0,1,2,3,4], dtype=tf.int64))
emb_b = tf.nn.embedding_lookup(b, tf.constant([5,6,7,8,9], dtype=tf.int64))
emb=tf.concat([emb_a, emb_b], axis=1)
loss=tf.reduce_sum(emb)
optimizer=tf.train.AdagradOptimizer(0.1)
train_op=optimizer.minimize(loss)
saver=tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run([emb, train_op]))
saver.save(sess, "./ckpt_test/ckpt/model")
import tensorflow as tf
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
from tensorflow.python.framework import meta_graph
ckpt_path='ckpt_test/ckpt'
path=ckpt_path+'/model.meta'
meta_graph = meta_graph.read_meta_graph_file(path)
ev_node_list=[]
for node in meta_graph.graph_def.node:
if node.op == 'KvVarHandleOp':
ev_node_list.append(node.name)
print("ev node list", ev_node_list)
# filter ev-slot
non_slot_ev_list=[]
for node in ev_node_list:
if "Adagrad" not in node:
non_slot_ev_list.append(node)
print("ev (exculde slot) node list", non_slot_ev_list)
for name in non_slot_ev_list:
print(name+'-keys', tf.train.load_variable(ckpt_path, name+'-keys'))
print(name+'-values', tf.train.load_variable(ckpt_path, name+'-values'))
print(name+'-freqs', tf.train.load_variable(ckpt_path, name+'-freqs'))
print(name+'-versions', tf.train.load_variable(ckpt_path, name+'-versions'))
root@i22b15440:/home/admin/chen.ding/gitlab/code/DeepRec# python ckpt_test/gen_ckpt.py
2022-04-18 08:56:37.462789: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2499445000 Hz
2022-04-18 08:56:37.468578: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7fd5bb36fc60 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2022-04-18 08:56:37.469203: I tensorflow/compiler/xla/service/service.cc:176] StreamExecutor device (0): Host, Default Version
[array([[ 0.19091757, -1.2494173 , -1.1098509 , -0.88375354, 1.5314401 ,
0.05683817, -0.3698739 , 1.7533029 , -0.95090204, 0.597882 ,
-0.39382792, 1.2318059 ],
[ 0.19091757, -1.2494173 , -1.1098509 , -0.88375354, -0.83169323,
0.15894873, -0.66453475, 0.84301287, 1.125458 , 0.12537971,
0.7338474 , -0.02672509],
[ 0.19091757, -1.2494173 , -1.1098509 , -0.88375354, -0.29247162,
-1.1461716 , 1.1172409 , 1.9220417 , -1.2039331 , -1.248681 ,
-1.9431682 , -0.27165115],
[ 0.19091757, -1.2494173 , -1.1098509 , -0.88375354, -0.78701115,
-0.28825614, 1.1483766 , -0.18648145, 0.7928211 , -0.4237969 ,
-0.00831279, 0.68605185],
[ 0.19091757, -1.2494173 , -1.1098509 , -0.88375354, 0.5981304 ,
0.9115529 , 1.2290154 , 1.329322 , 0.81167835, 0.43949136,
0.4266789 , 0.5895692 ]], dtype=float32), None]
root@i22b15440:/home/admin/chen.ding/gitlab/code/DeepRec# ls ckpt_test/ckpt/
checkpoint model.data-00000-of-00001 model.index model.meta
root@i22b15440:/home/admin/chen.ding/gitlab/code/DeepRec# python ckpt_test/read_ckpt.py
ev node list ['a', 'b/part_0', 'b/part_1', 'b/part_2', 'b/part_3', 'a/Adagrad', 'b/part_0/Adagrad', 'b/part_1/Adagrad', 'b/part_2/Adagrad', 'b/part_3/Adagrad']
ev (exculde slot) node list ['a', 'b/part_0', 'b/part_1', 'b/part_2', 'b/part_3']
a-keys [0 1 2 3 4]
a-values [[ 0.19091757 -1.2494173 -1.1098509 -0.88375354]
[ 0.19091757 -1.2494173 -1.1098509 -0.88375354]
[ 0.19091757 -1.2494173 -1.1098509 -0.88375354]
[ 0.19091757 -1.2494173 -1.1098509 -0.88375354]
[ 0.19091757 -1.2494173 -1.1098509 -0.88375354]]
a-freqs []
a-versions []
b/part_0-keys [8]
b/part_0-values [[-0.8823574 -0.3836024 1.0530304 -0.28182772 0.69747484 -0.51914316
-0.10365905 0.5907056 ]]
b/part_0-freqs []
b/part_0-versions []
b/part_1-keys [5 9]
b/part_1-values [[ 1.4360939 -0.03850809 -0.46522018 1.6579567 -1.0462483 0.5025357
-0.4891742 1.1364597 ]
[ 0.50278413 0.81620663 1.1336691 1.2339758 0.7163321 0.3441451
0.33133262 0.49422294]]
b/part_1-freqs []
b/part_1-versions []
b/part_2-keys [6]
b/part_2-values [[-0.83169323 0.15894873 -0.66453475 0.84301287 1.125458 0.12537971
0.7338474 -0.02672509]]
b/part_2-freqs []
b/part_2-versions []
b/part_3-keys [7]
b/part_3-values [[-0.3878179 -1.2415178 1.0218947 1.8266954 -1.2992793 -1.3440272
-2.0385144 -0.36699742]]
b/part_3-freqs []
b/part_3-versions []