
[TensorFlow] Using tf.nn.embedding_lookup

唯米天空 2022-01-17

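tf.nn.embedding_lookup maps indices to entries of a lookup table. For every index i in ids, it returns params[i], the corresponding row of the table params.

Let's go straight to some examples.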
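When params is a single tensor rather than a list of shards, the lookup behaves like a gather along the first axis. Each index i in ids is replaced by params[i], and the result has shape ids.shape + params.shape[1:]. The sketch below mimics this in NumPy so it runs without a TensorFlow session. `lookup` is a hypothetical helper for illustration only, not a TensorFlow function, and it ignores the sharding and max_norm options of the real op.

```python
import numpy as np

def lookup(params, ids):
    """Hypothetical NumPy stand-in for tf.nn.embedding_lookup with a single,
    unpartitioned params array: gather rows of params along axis 0."""
    params = np.asarray(params)
    ids = np.asarray(ids)
    # np.take raises IndexError for out-of-range ids (mode='raise' by default).
    # Result shape: ids.shape + params.shape[1:].
    return np.take(params, ids, axis=0)

# A 3-row embedding table; each row is a 2-dimensional vector.
table = np.array([[0.0, 0.1],
                  [1.0, 1.1],
                  [2.0, 2.1]])
out = lookup(table, [2, 0, 2])
print(out.shape)  # (3, 2)
```

In real models each row is usually a vector, so looking up 3 ids in a table with 2-dimensional rows yields a 3 × 2 result.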

Example 1

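import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session)

t = np.asarray([1,2,3,0])

params = tf.constant([10,20,30,40])
# Gathers params[1], params[2], params[3], params[0]
embedded_inputs = tf.nn.embedding_lookup(params, t)

with tf.Session() as sess:
    print(sess.run(embedded_inputs))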

Result:

[20 30 40 10]

The result consists of the elements of params, looked up in the order given by t.
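To check this reading of Example 1 without a TensorFlow session, the same lookup can be reproduced with NumPy fancy indexing. This is a sketch assuming a single 1-D params tensor, for which embedding_lookup matches params[t].

```python
import numpy as np

params = np.array([10, 20, 30, 40])
t = np.asarray([1, 2, 3, 0])

# Position k of the result is params[t[k]], so the output follows t's order.
result = params[t]
print(result)  # [20 30 40 10]

# Indices may repeat and need not cover every row of the table.
repeated = params[np.asarray([3, 3, 1])]
print(repeated)  # [40 40 20]
```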

Example 2

t = np.asarray([1,2,3,0])

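with tf.variable_scope('test', reuse=tf.AUTO_REUSE):
    t = tf.convert_to_tensor(t)  # integer tensor of indices
    # The table is now a float32 tf.Variable instead of a constant.
    # (reuse=tf.AUTO_REUSE only affects tf.get_variable; tf.Variable
    # always creates a new variable.)
    embedding_table = tf.Variable([10,20,30,40],
        name='phone_embedding', shape=[4],
        dtype=tf.float32)
    embedded_inputs = tf.nn.embedding_lookup(embedding_table, t)

# Variables must be initialized before they can be read.
init_ops = [tf.global_variables_initializer(),
            tf.local_variables_initializer()]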
with tf.Session() as sess:
    sess.run(init_ops)

    print(sess.run(t))
    print(sess.run(embedding_table))
    print(sess.run(embedded_inputs))

Result:

[1 2 3 0]
[10. 20. 30. 40.]
[20. 30. 40. 10.]

Example 3

The indices form a matrix (a 2-D array)

# Index t is a 2-D array of shape [2, 2]
t = np.asarray([1,2,3,0]).reshape([2,2])

with tf.variable_scope('test', reuse=tf.AUTO_REUSE):
    t = tf.convert_to_tensor(t)
    # Same 1-D table as in Example 2; it is NOT reshaped
    embedding_table = tf.Variable([10,20,30,40],
        name='phone_embedding', shape=[4],
        dtype=tf.float32)
    # The output takes the shape of t: [2, 2]
    embedded_inputs = tf.nn.embedding_lookup(embedding_table, t)

init_ops = [tf.global_variables_initializer(),
            tf.local_variables_initializer()]
with tf.Session() as sess:
    sess.run(init_ops)

    print(sess.run(t))
    print(sess.run(embedding_table))
    print(sess.run(embedded_inputs))

Result:

[[1 2]
 [3 0]]
[10. 20. 30. 40.]
[[20. 30.]
 [40. 10.]]

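Note that embedding_table does not need to be reshaped to match t. The output simply takes the shape of t, followed by any remaining dimensions of the table. Reshaping the table (e.g. to [2, 2]) would instead raise an error, because its first dimension would shrink to 2 and indices 2 and 3 would then be out of range.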
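The shape rule and the failure case can be sketched in NumPy, using np.take along axis 0 as a stand-in for the TensorFlow op. This is an analogy, not TensorFlow itself: NumPy raises IndexError here, while TensorFlow reports its own error type. The 2-D table `emb` is a hypothetical addition, not from the original examples. It shows that a table with vector rows appends its row dimension to the output shape.

```python
import numpy as np

table = np.array([10.0, 20.0, 30.0, 40.0])    # shape [4], as in Example 3
t = np.asarray([1, 2, 3, 0]).reshape([2, 2])  # shape [2, 2]

# Output shape = t.shape + table.shape[1:] = (2, 2) + () = (2, 2)
out = np.take(table, t, axis=0)
print(out)
# [[20. 30.]
#  [40. 10.]]

# Hypothetical table with 2-dimensional rows: shape becomes (2, 2) + (2,)
emb = np.arange(8.0).reshape([4, 2])
vec_out = np.take(emb, t, axis=0)
print(vec_out.shape)  # (2, 2, 2)

# Reshaping the table to [2, 2] leaves only rows 0 and 1,
# so indices 2 and 3 in t are out of range.
raised = False
try:
    np.take(table.reshape([2, 2]), t, axis=0)
except IndexError:
    raised = True
print(raised)  # True
```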


